diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..e5ff4bf --- /dev/null +++ b/.gitignore @@ -0,0 +1,425 @@ +## Ignore Visual Studio temporary files, build results, and +## files generated by popular Visual Studio add-ons. +## +## Get latest from https://github.com/github/gitignore/blob/main/VisualStudio.gitignore + +# User-specific files +*.rsuser +*.suo +*.user +*.userosscache +*.sln.docstates + +# User-specific files (MonoDevelop/Xamarin Studio) +*.userprefs + +# Mono auto generated files +mono_crash.* + +# Build results +[Dd]ebug/ +[Dd]ebugPublic/ +[Rr]elease/ +[Rr]eleases/ +x64/ +x86/ +[Ww][Ii][Nn]32/ +[Aa][Rr][Mm]/ +[Aa][Rr][Mm]64/ +bld/ +[Bb]in/ +[Oo]bj/ +[Ll]og/ +[Ll]ogs/ + +# Visual Studio 2015/2017 cache/options directory +.vs/ +# Uncomment if you have tasks that create the project's static files in wwwroot +#wwwroot/ + +# Visual Studio 2017 auto generated files +Generated\ Files/ + +# MSTest test Results +[Tt]est[Rr]esult*/ +[Bb]uild[Ll]og.* + +# NUnit +*.VisualState.xml +TestResult.xml +nunit-*.xml + +# Build Results of an ATL Project +[Dd]ebugPS/ +[Rr]eleasePS/ +dlldata.c + +# Benchmark Results +BenchmarkDotNet.Artifacts/ + +# .NET Core +project.lock.json +project.fragment.lock.json +artifacts/ + +# ASP.NET Scaffolding +ScaffoldingReadMe.txt + +# StyleCop +StyleCopReport.xml + +# Files built by Visual Studio +*_i.c +*_p.c +*_h.h +*.ilk +*.meta +*.obj +*.iobj +*.pch +*.pdb +*.ipdb +*.pgc +*.pgd +*.rsp +*.sbr +*.tlb +*.tli +*.tlh +*.tmp +*.tmp_proj +*_wpftmp.csproj +*.log +*.tlog +*.vspscc +*.vssscc +.builds +*.pidb +*.svclog +*.scc + +# Chutzpah Test files +_Chutzpah* + +# Visual C++ cache files +ipch/ +*.aps +*.ncb +*.opendb +*.opensdf +*.sdf +*.cachefile +*.VC.db +*.VC.VC.opendb + +# Visual Studio profiler +*.psess +*.vsp +*.vspx +*.sap + +# Visual Studio Trace Files +*.e2e + +# TFS 2012 Local Workspace +$tf/ + +# Guidance Automation Toolkit +*.gpState + +# ReSharper is a .NET coding add-in +_ReSharper*/ +*.[Rr]e[Ss]harper +*.DotSettings.user + +# TeamCity is a build add-in +_TeamCity* + +# DotCover is a Code Coverage Tool +*.dotCover + +# AxoCover is a Code Coverage Tool +.axoCover/* +!.axoCover/settings.json + +# Coverlet is a free, cross platform Code Coverage Tool +coverage*.json +coverage*.xml +coverage*.info + +# Visual Studio code coverage results +*.coverage +*.coveragexml + +# NCrunch +_NCrunch_* +.*crunch*.local.xml +nCrunchTemp_* + +# MightyMoose +*.mm.* +AutoTest.Net/ + +# Web workbench (sass) +.sass-cache/ + +# Installshield output folder +[Ee]xpress/ + +# DocProject is a documentation generator add-in +DocProject/buildhelp/ +DocProject/Help/*.HxT +DocProject/Help/*.HxC +DocProject/Help/*.hhc +DocProject/Help/*.hhk +DocProject/Help/*.hhp +DocProject/Help/Html2 +DocProject/Help/html + +# Click-Once directory +publish/ + +# Publish Web Output +*.[Pp]ublish.xml +*.azurePubxml +# Note: Comment the next line if you want to checkin your web deploy settings, +# but database connection strings (with potential passwords) will be unencrypted +*.pubxml +*.publishproj + +# Microsoft Azure Web App publish settings. Comment the next line if you want to +# checkin your Azure Web App publish settings, but sensitive information contained +# in these scripts will be unencrypted +PublishScripts/ + +# NuGet Packages +*.nupkg +# NuGet Symbol Packages +*.snupkg +# The packages folder can be ignored because of Package Restore +**/[Pp]ackages/* +# except build/, which is used as an MSBuild target. 
+!**/[Pp]ackages/build/ +# Uncomment if necessary however generally it will be regenerated when needed +#!**/[Pp]ackages/repositories.config +# NuGet v3's project.json files produces more ignorable files +*.nuget.props +*.nuget.targets + +# Microsoft Azure Build Output +csx/ +*.build.csdef + +# Microsoft Azure Emulator +ecf/ +rcf/ + +# Windows Store app package directories and files +AppPackages/ +BundleArtifacts/ +Package.StoreAssociation.xml +_pkginfo.txt +*.appx +*.appxbundle +*.appxupload + +# Visual Studio cache files +# files ending in .cache can be ignored +*.[Cc]ache +# but keep track of directories ending in .cache +!?*.[Cc]ache/ + +# Others +ClientBin/ +~$* +*~ +*.dbmdl +*.dbproj.schemaview +*.jfm +*.pfx +*.publishsettings +orleans.codegen.cs + +# Including strong name files can present a security risk +# (https://github.com/github/gitignore/pull/2483#issue-259490424) +#*.snk + +# Since there are multiple workflows, uncomment next line to ignore bower_components +# (https://github.com/github/gitignore/pull/1529#issuecomment-104372622) +#bower_components/ + +# RIA/Silverlight projects +Generated_Code/ + +# Backup & report files from converting an old project file +# to a newer Visual Studio version. Backup files are not needed, +# because we have git ;-) +_UpgradeReport_Files/ +Backup*/ +UpgradeLog*.XML +UpgradeLog*.htm +ServiceFabricBackup/ +*.rptproj.bak + +# SQL Server files +*.mdf +*.ldf +*.ndf + +# Business Intelligence projects +*.rdl.data +*.bim.layout +*.bim_*.settings +*.rptproj.rsuser +*- [Bb]ackup.rdl +*- [Bb]ackup ([0-9]).rdl +*- [Bb]ackup ([0-9][0-9]).rdl + +# Microsoft Fakes +FakesAssemblies/ + +# GhostDoc plugin setting file +*.GhostDoc.xml + +# Node.js Tools for Visual Studio +.ntvs_analysis.dat +node_modules/ + +# Visual Studio 6 build log +*.plg + +# Visual Studio 6 workspace options file +*.opt + +# Visual Studio 6 auto-generated workspace file (contains which files were open etc.) +*.vbw + +# Visual Studio 6 auto-generated project file (contains which files were open etc.) 
+*.vbp
+
+# Visual Studio 6 workspace and project file (working project files containing files to include in project)
+*.dsw
+*.dsp
+
+# Visual Studio 6 technical files
+*.ncb
+*.aps
+
+# Visual Studio LightSwitch build output
+**/*.HTMLClient/GeneratedArtifacts
+**/*.DesktopClient/GeneratedArtifacts
+**/*.DesktopClient/ModelManifest.xml
+**/*.Server/GeneratedArtifacts
+**/*.Server/ModelManifest.xml
+_Pvt_Extensions
+
+# Paket dependency manager
+.paket/paket.exe
+paket-files/
+
+# FAKE - F# Make
+.fake/
+
+# CodeRush personal settings
+.cr/personal
+
+# Python Tools for Visual Studio (PTVS)
+__pycache__/
+*.pyc
+
+# Cake - Uncomment if you are using it
+# tools/**
+# !tools/packages.config
+
+# Tabs Studio
+*.tss
+
+# Telerik's JustMock configuration file
+*.jmconfig
+
+# BizTalk build output
+*.btp.cs
+*.btm.cs
+*.odx.cs
+*.xsd.cs
+
+# OpenCover UI analysis results
+OpenCover/
+
+# Azure Stream Analytics local run output
+ASALocalRun/
+
+# MSBuild Binary and Structured Log
+*.binlog
+
+# NVidia Nsight GPU debugger configuration file
+*.nvuser
+
+# MFractors (Xamarin productivity tool) working folder
+.mfractor/
+
+# Local History for Visual Studio
+.localhistory/
+
+# Visual Studio History (VSHistory) files
+.vshistory/
+
+# BeatPulse healthcheck temp database
+healthchecksdb
+
+# Backup folder for Package Reference Convert tool in Visual Studio 2017
+MigrationBackup/
+
+# Ionide (cross platform F# VS Code tools) working folder
+.ionide/
+
+# Fody - auto-generated XML schema
+FodyWeavers.xsd
+
+# VS Code files for those working on multiple tools
+.vscode/*
+!.vscode/settings.json
+!.vscode/tasks.json
+!.vscode/launch.json
+!.vscode/extensions.json
+*.code-workspace
+
+# Local History for Visual Studio Code
+.history/
+
+# Windows Installer files from build outputs
+*.cab
+*.msi
+*.msix
+*.msm
+*.msp
+
+# JetBrains Rider
+*.sln.iml
+
+*.o
+*.a
+.cache/
+.vs/
+.vscode/
+.DS_Store
+
+.build/
+build/
+build-debug/
+
+_deps/
+*.cmake
+compile_commands.json
+CMakeFiles/
+CMakeCache.txt
+
+models/*
+
+.envrc
+.direnv/
+
+.venv
+__pycache__
+.idea
+Makefile
diff --git a/CMakeLists.txt b/CMakeLists.txt
new file mode 100644
index 0000000..386abad
--- /dev/null
+++ b/CMakeLists.txt
@@ -0,0 +1,391 @@
+cmake_minimum_required(VERSION 3.3)
+project(minigpt4.cpp C CXX)
+
+set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
+set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
+
+include(FetchContent)
+
+# General
+option(MINIGPT4_BUILD_WITH_OPENCV "minigpt4: build opencv (loading and encoding in c++)" OFF)
+option(MINIGPT4_BUILD_EXAMPLES "minigpt4: build examples" OFF)
+option(MINIGPT4_BUILD_SHARED_LIBRARY "minigpt4: build as a shared library" ON)
+option(MINIGPT4_STATIC "minigpt4: static link libraries" OFF)
+option(MINIGPT4_NATIVE "minigpt4: enable -march=native flag" OFF)
+option(MINIGPT4_LTO "minigpt4: enable link time optimization" OFF)
+
+# Debug
+option(MINIGPT4_ALL_WARNINGS "minigpt4: enable all compiler warnings" ON)
+option(MINIGPT4_GPROF "minigpt4: enable gprof" OFF)
+
+# Sanitizers
+option(MINIGPT4_SANITIZE_THREAD "minigpt4: enable thread sanitizer" OFF)
+option(MINIGPT4_SANITIZE_ADDRESS "minigpt4: enable address sanitizer" OFF)
+option(MINIGPT4_SANITIZE_UNDEFINED "minigpt4: enable undefined sanitizer" OFF)
+
+# Instruction set specific
+option(MINIGPT4_AVX "minigpt4: enable AVX" ON)
+option(MINIGPT4_AVX2 "minigpt4: enable AVX2" ON)
+option(MINIGPT4_AVX512 "minigpt4: enable AVX512" OFF)
+option(MINIGPT4_FMA "minigpt4: enable FMA" ON)
+
+# 3rd party libs
+option(MINIGPT4_ACCELERATE "minigpt4: enable Accelerate framework" ON)
+option(MINIGPT4_OPENBLAS "minigpt4: use OpenBLAS" OFF)
+option(MINIGPT4_CUBLAS "minigpt4: use cuBLAS" OFF)
+
+# Build only shared library without building tests and extras
+option(MINIGPT4_STANDALONE "minigpt4: build only MINIGPT4 library" OFF)
+
+#
+# Compile flags
+#
+
+set(CMAKE_C_FLAGS_DEBUG "-g -DDEBUG")
+set(CMAKE_CXX_FLAGS_DEBUG "-g -DDEBUG")
+
+set(CMAKE_CXX_STANDARD 23)
+set(CMAKE_CXX_STANDARD_REQUIRED true)
+set(CMAKE_C_STANDARD 11)
+set(CMAKE_C_STANDARD_REQUIRED true)
+set(THREADS_PREFER_PTHREAD_FLAG ON)
+find_package(Threads REQUIRED)
+
+if (NOT MSVC)
+    if (MINIGPT4_SANITIZE_THREAD)
+        add_compile_options(-fsanitize=thread)
+        link_libraries(-fsanitize=thread)
+    endif()
+
+    if (MINIGPT4_SANITIZE_ADDRESS)
+        add_compile_options(-fsanitize=address -fno-omit-frame-pointer)
+        link_libraries(-fsanitize=address)
+    endif()
+
+    if (MINIGPT4_SANITIZE_UNDEFINED)
+        add_compile_options(-fsanitize=undefined)
+        link_libraries(-fsanitize=undefined)
+    endif()
+endif()
+
+if (APPLE AND MINIGPT4_ACCELERATE)
+    find_library(ACCELERATE_FRAMEWORK Accelerate)
+    if (ACCELERATE_FRAMEWORK)
+        message(STATUS "Accelerate framework found")
+
+        add_compile_definitions(GGML_USE_ACCELERATE)
+        set(MINIGPT4_EXTRA_LIBS ${MINIGPT4_EXTRA_LIBS} ${ACCELERATE_FRAMEWORK})
+    else()
+        message(WARNING "Accelerate framework not found")
+    endif()
+endif()
+
+if (MINIGPT4_OPENBLAS)
+    if (MINIGPT4_STATIC)
+        set(BLA_STATIC ON)
+    endif()
+
+    set(BLA_VENDOR OpenBLAS)
+    find_package(BLAS)
+    if (BLAS_FOUND)
+        message(STATUS "OpenBLAS found")
+
+        add_compile_definitions(GGML_USE_OPENBLAS)
+        add_link_options(${BLAS_LIBRARIES})
+    else()
+        message(WARNING "OpenBLAS not found")
+    endif()
+endif()
+
+if (MINIGPT4_CUBLAS)
+    cmake_minimum_required(VERSION 3.17)
+
+    find_package(CUDAToolkit)
+    if (CUDAToolkit_FOUND)
+        message(STATUS "cuBLAS found")
+
+        enable_language(CUDA)
+
+        set(GGML_CUDA_SOURCES ${CMAKE_SOURCE_DIR}/ggml/src/ggml-cuda.cu ${CMAKE_SOURCE_DIR}/ggml/src/ggml-cuda.h)
+
+        add_compile_definitions(GGML_USE_CUBLAS)
+
+        if (MINIGPT4_STATIC)
+            set(MINIGPT4_EXTRA_LIBS ${MINIGPT4_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
+        else()
+            set(MINIGPT4_EXTRA_LIBS ${MINIGPT4_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt)
+        endif()
+
+    else()
+        message(WARNING "cuBLAS not found")
+    endif()
+endif()
+
+if (MINIGPT4_ALL_WARNINGS)
+    if (NOT MSVC)
+        set(c_flags
+            -Wall
+            -Wextra
+            -Wpedantic
+            -Wcast-qual
+            -Wdouble-promotion
+            -Wshadow
+            -Wstrict-prototypes
+            -Wpointer-arith
+            -Wno-unused-function
+        )
+        set(cxx_flags
+            -Wall
+            -Wextra
+            -Wpedantic
+            -Wcast-qual
+            -Wno-unused-function
+            -Wno-multichar
+        )
+    else()
+        set(c_flags
+            -W4
+        )
+        set(cxx_flags
+            -W4
+        )
+    endif()
+
+    add_compile_options(
+        "$<$<COMPILE_LANGUAGE:C>:${c_flags}>"
+        "$<$<COMPILE_LANGUAGE:CXX>:${cxx_flags}>"
+    )
+
+endif()
+
+if (MINIGPT4_LTO)
+    include(CheckIPOSupported)
+    check_ipo_supported(RESULT result OUTPUT output)
+    if (result)
+        set(CMAKE_INTERPROCEDURAL_OPTIMIZATION TRUE)
+    else()
+        message(WARNING "IPO is not supported: ${output}")
+    endif()
+endif()
+
+# Architecture specific
+# TODO: probably these flags need to be tweaked on some architectures
+#       feel free to update the Makefile for your architecture and send a pull request or issue
+message(STATUS "CMAKE_SYSTEM_PROCESSOR: ${CMAKE_SYSTEM_PROCESSOR}")
+if (NOT MSVC)
+    if (MINIGPT4_STATIC)
+        add_link_options(-static)
+        if (MINGW)
+            add_link_options(-static-libgcc -static-libstdc++)
+        endif()
+    endif()
+    if (MINIGPT4_GPROF)
+        add_compile_options(-pg)
+    endif()
+    if (MINIGPT4_NATIVE)
+        add_compile_options(-march=native)
+    endif()
+endif()
+
+if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm" OR ${CMAKE_SYSTEM_PROCESSOR} MATCHES "aarch64")
+    message(STATUS "ARM detected")
+    if (MSVC)
+        # TODO: arm msvc?
+    else()
+        if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "aarch64")
+            add_compile_options(-mcpu=native)
+        endif()
+        # TODO: armv6,7,8 version specific flags
+    endif()
+elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "^(x86_64|i686|AMD64)$")
+    message(STATUS "x86 detected")
+    if (MSVC)
+        if (MINIGPT4_AVX512)
+            add_compile_options($<$<COMPILE_LANGUAGE:C>:/arch:AVX512>)
+            add_compile_options($<$<COMPILE_LANGUAGE:CXX>:/arch:AVX512>)
+            # MSVC has no compile-time flags enabling specific
+            # AVX512 extensions, neither it defines the
+            # macros corresponding to the extensions.
+            # Do it manually.
+        elseif (MINIGPT4_AVX2)
+            add_compile_options($<$<COMPILE_LANGUAGE:C>:/arch:AVX2>)
+            add_compile_options($<$<COMPILE_LANGUAGE:CXX>:/arch:AVX2>)
+        elseif (MINIGPT4_AVX)
+            add_compile_options($<$<COMPILE_LANGUAGE:C>:/arch:AVX>)
+            add_compile_options($<$<COMPILE_LANGUAGE:CXX>:/arch:AVX>)
+        endif()
+    else()
+        add_compile_options(-mf16c)
+        if (MINIGPT4_FMA)
+            add_compile_options(-mfma)
+        endif()
+        if (MINIGPT4_AVX)
+            add_compile_options(-mavx)
+        endif()
+        if (MINIGPT4_AVX2)
+            add_compile_options(-mavx2)
+        endif()
+        if (MINIGPT4_AVX512)
+            add_compile_options(-mavx512f)
+            add_compile_options(-mavx512bw)
+        endif()
+    endif()
+else()
+    # TODO: support PowerPC
+    message(STATUS "Unknown architecture")
+endif()
+
+#
+# Build libraries
+#
+
+if (MSVC)
+    add_compile_definitions(_CRT_SECURE_NO_WARNINGS)
+endif()
+
+if (MINIGPT4_BUILD_SHARED_LIBRARY)
+    set(MINIGPT4_LIBRARY_BUILD SHARED)
+else()
+    set(MINIGPT4_LIBRARY_BUILD STATIC)
+endif()
+
+macro(add_dependency)
+    SET(dependency_name ${ARGV0})
+    SET(endpoint_url ${ARGV1})
+    SET(endpoint_tag ${ARGV2})
+    SET(do_build_with_cmake ${ARGV3})
+
+    FetchContent_Declare(
+        ${dependency_name}
+        GIT_REPOSITORY ${endpoint_url}
+        GIT_TAG ${endpoint_tag}
+    )
+
+    FetchContent_GetProperties(${dependency_name})
+
+    if (NOT ${dependency_name}_POPULATED)
+        FetchContent_Populate(${dependency_name})
+        message(STATUS "Working on ${dependency_name}")
+
+        if (${do_build_with_cmake})
+            add_subdirectory(${${dependency_name}_SOURCE_DIR} ${${dependency_name}_BINARY_DIR})
+        else ()
+            message("\tHeader only")
+        endif ()
+    endif ()
+endmacro()
+
+set(MINIGPT4_MSVC_USE_STATIC_CRT on CACHE BOOL "Use MT flags when compiling in MSVC")
+if (MSVC)
+    if (MINIGPT4_MSVC_USE_STATIC_CRT)
+        message("-- Using static CRT linking ${MINIGPT4_MSVC_USE_STATIC_CRT}")
+        foreach(flag_var CMAKE_CXX_FLAGS_DEBUG CMAKE_CXX_FLAGS_RELEASE
+                CMAKE_CXX_FLAGS_MINSIZEREL CMAKE_CXX_FLAGS_RELWITHDEBINFO
+                CMAKE_C_FLAGS_DEBUG CMAKE_C_FLAGS_RELEASE
+                CMAKE_C_FLAGS_MINSIZEREL CMAKE_C_FLAGS_RELWITHDEBINFO)
+            string(REPLACE "/MD" "/MT" ${flag_var} "${${flag_var}}")
+        endforeach()
+    endif()
+endif()
+
+if (MINIGPT4_BUILD_SHARED_LIBRARY)
+    # hack...
+    set(CMAKE_POSITION_INDEPENDENT_CODE ON)
+    # set_property(TARGET fmt PROPERTY POSITION_INDEPENDENT_CODE ON)
+endif()
+
+#add_dependency(ggml https://github.com/ggerganov/ggml 93b94a2d41e880cb2abfb708535d5b04ad05b7a5 TRUE)
+add_dependency(fmt https://github.com/fmtlib/fmt 9.1.0 TRUE)
+add_dependency(unordered_dense https://github.com/martinus/unordered_dense v4.0.0 TRUE)
+add_dependency(stb https://github.com/nothings/stb 5736b15 FALSE)
+add_dependency(spdlog https://github.com/gabime/spdlog v1.11.0 TRUE)
+add_dependency(nlohmann_json https://github.com/nlohmann/json v3.11.2 TRUE)
+
+set(EXPECTED_BUILD_TESTS OFF)
+add_dependency(tl_expected https://github.com/TartanLlama/expected v1.1.0 TRUE)
+
+set(LLAMA_STATIC ${MINIGPT4_STATIC})
+set(LLAMA_NATIVE ${MINIGPT4_NATIVE})
+set(LLAMA_LTO ${MINIGPT4_LTO})
+set(LLAMA_AVX ${MINIGPT4_AVX})
+set(LLAMA_AVX2 ${MINIGPT4_AVX2})
+set(LLAMA_AVX512 ${MINIGPT4_AVX512})
+set(LLAMA_AVX512_VBMI ${MINIGPT4_AVX512_VBMI})
+set(LLAMA_AVX512_VNNI ${MINIGPT4_AVX512_VNNI})
+set(LLAMA_FMA ${MINIGPT4_FMA})
+set(LLAMA_ACCELERATE ${MINIGPT4_ACCELERATE})
+set(GGML_USE_K_QUANTS ON)
+add_dependency(llama_cpp https://github.com/ggerganov/llama.cpp master-31cfbb1 TRUE)
+
+set(OPENCV_INCLUDE_DIRS "")
+set(OPENCV_LIBS "")
+set(PILLOW_RESIZE_INCLUDE_DIRS "")
+set(PILLOW_RESIZE_LIBS "")
+
+if (MINIGPT4_BUILD_WITH_OPENCV)
+    find_package(OpenCV REQUIRED)
+    set(OPENCV_INCLUDE_DIRS ${OpenCV_INCLUDE_DIRS})
+    set(OPENCV_LIBS ${OpenCV_LIBS})
+
+    add_dependency(pillow_resize https://github.com/zurutech/pillow-resize 4427c50 TRUE)
+
+    set(PILLOW_RESIZE_INCLUDE_DIRS ${pillow_resize_SOURCE_DIR}/include/PillowResize)
+    set(PILLOW_RESIZE_LIBS PillowResize)
+    add_compile_definitions(MINIGPT4_BUILD_WITH_OPENCV)
+else()
+    add_dependency(magic_enum https://github.com/Neargye/magic_enum v0.9.3 TRUE)
+endif()
+
+add_library(minigpt4 ${MINIGPT4_LIBRARY_BUILD}
+    minigpt4.cpp
+    minigpt4.h)
+
+target_include_directories(minigpt4 PUBLIC
+    .
+
+    ${fmt_SOURCE_DIR}
+#    ${ggml_SOURCE_DIR}
+    ${unordered_dense_SOURCE_DIR}
+    ${stb_SOURCE_DIR}
+    ${spdlog_SOURCE_DIR}
+    ${nlohmann_json_SOURCE_DIR}
+    ${tokenizers_cpp_SOURCE_DIR}
+    ${llama_cpp_SOURCE_DIR}
+    ${magic_enum_SOURCE_DIR}
+    ${tl_expected_SOURCE_DIR}/include/tl
+
+    ${OPENCV_INCLUDE_DIRS}
+    ${PILLOW_RESIZE_INCLUDE_DIRS}
+)
+
+target_link_libraries(minigpt4 PUBLIC
+    fmt
+#    ggml
+    unordered_dense
+    spdlog
+    nlohmann_json
+    llama
+    magic_enum
+    expected
+
+    ${OPENCV_LIBS}
+    ${PILLOW_RESIZE_LIBS}
+)
+
+target_link_libraries(minigpt4 PRIVATE ${MINIGPT4_EXTRA_LIBS})
+
+if (MSVC)
+    if (CMAKE_BUILD_TYPE STREQUAL "Debug")
+        target_compile_options(minigpt4 PUBLIC "/ZI")
+        target_link_options(minigpt4 PUBLIC "/INCREMENTAL")
+    endif()
+endif()
+
+if (MINIGPT4_BUILD_SHARED_LIBRARY)
+    set_target_properties(minigpt4 PROPERTIES POSITION_INDEPENDENT_CODE ON)
+    target_compile_definitions(minigpt4 PRIVATE MINIGPT4_SHARED MINIGPT4_BUILD)
+endif()
+
+if (MINIGPT4_BUILD_EXAMPLES)
+    add_subdirectory(examples)
+endif()
\ No newline at end of file
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..99bfc08
--- /dev/null
+++ b/README.md
@@ -0,0 +1,141 @@
+# minigpt4.cpp
+
+
+
+Inference of [MiniGPT4](https://github.com/Vision-CAIR/MiniGPT-4) in pure C/C++.
+
+## Description
+
+The main goal of `minigpt4.cpp` is to run MiniGPT4 using 4-bit quantization with the [ggml](https://github.com/ggerganov/ggml) library.
+
+## Demo
+
+![minigpt1](assets/webui_demo.png)
+
+![minigpt1](assets/minigpt4-demo1.gif)
+
+## Usage
+### 1. Clone repo
+
+**Requirements**: [git](https://gitforwindows.org/)
+
+```bash
+git clone --recursive https://github.com/Maknee/minigpt4.cpp
+cd minigpt4.cpp
+```
+
+### 2. Getting the library
+
+#### Option 1: Download precompiled binary
+
+##### Windows / Linux / MacOS
+
+Go to [Releases](https://github.com/Maknee/minigpt4.cpp/releases) and extract the `minigpt4` library file into the repository directory.
+
+#### Option 2: Build library manually
+
+##### Windows
+
+**Requirements**: [CMake](https://cmake.org/download/), [Visual Studio](https://visualstudio.microsoft.com/) and [Git](https://gitforwindows.org/)
+
+```commandline
+cmake .
+cmake --build . --config Release
+```
+
+`bin\Release\minigpt4.dll` should be generated
+
+##### Linux
+
+**Requirements**: CMake (Ubuntu: `sudo apt install cmake`)
+
+```bash
+cmake .
+cmake --build . --config Release
+```
+
+`minigpt4.so` should be generated
+
+##### MacOS
+
+**Requirements**: CMake (MacOS: `brew install cmake`)
+
+```sh
+cmake .
+cmake --build . --config Release
+```
+
+`minigpt4.dylib` should be generated
+
+**Note:** If you build with opencv (allowing features such as loading and preprocessing an image within the library itself), set `MINIGPT4_BUILD_WITH_OPENCV` to `ON` in `CMakeLists.txt` or pass `-DMINIGPT4_BUILD_WITH_OPENCV=ON` to the cmake cli.
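+
+For example, a fresh OpenCV-enabled release build might look like this (same in-source layout as the commands above):
+
+```sh
+cmake -DMINIGPT4_BUILD_WITH_OPENCV=ON .
+cmake --build . --config Release
+```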
+### 3. Obtaining the model
+
+#### Option 1: Download pre-quantized MiniGPT4 model
+
+Pre-quantized models are available on Hugging Face ~ [7B](https://huggingface.co/datasets/maknee/minigpt4-7b-ggml/tree/main) or [13B](https://huggingface.co/datasets/maknee/minigpt4-13b-ggml/tree/main).
+
+#### Option 2: Convert and quantize PyTorch model
+
+**Requirements**: [Python 3.x](https://www.python.org/downloads/) and [PyTorch](https://pytorch.org/get-started/locally/).
+
+Clone the [MiniGPT-4](https://github.com/Vision-CAIR/MiniGPT-4) repository and perform the setup
+
+```sh
+cd minigpt4
+git clone https://github.com/Vision-CAIR/MiniGPT-4.git
+conda env create -f environment.yml
+conda activate minigpt4
+```
+
+Download the pretrained checkpoint in the [MiniGPT-4](https://github.com/Vision-CAIR/MiniGPT-4) repository under `Checkpoint Aligned with Vicuna 7B` or `Checkpoint Aligned with Vicuna 13B`, or download it from the Hugging Face links for [7B](https://huggingface.co/datasets/maknee/minigpt4-7b-ggml/blob/main/pretrained_minigpt4_7b.pth) or [13B](https://huggingface.co/datasets/maknee/minigpt4-13b-ggml/blob/main/pretrained_minigpt4.pth)
+
+Convert the model weights into ggml format
+
+##### Windows
+
+```commandline
+cd minigpt4
+python convert.py C:\pretrained_minigpt4.pth --ftype=f16
+```
+
+##### Linux / MacOS
+
+```sh
+python convert.py ~/Downloads/pretrained_minigpt4.pth --outtype f16
+```
+
+`minigpt4-7B-f16.bin` or `minigpt4-13B-f16.bin` should be generated
+
+### 4. Obtaining the vicuna model
+
+#### Option 1: Download pre-quantized vicuna-v0 model
+
+Pre-quantized models are available on [Hugging Face](https://huggingface.co/datasets/maknee/ggml-vicuna-v0-quantized/tree/main)
+
+#### Option 2: Convert and quantize vicuna-v0 model
+
+**Requirements**: [Python 3.x](https://www.python.org/downloads/) and [PyTorch](https://pytorch.org/get-started/locally/).
+
+Follow the [guide from MiniGPT4](https://github.com/Vision-CAIR/MiniGPT-4/blob/main/PrepareVicuna.md) to obtain the vicuna-v0 model.
+
+Then, clone llama.cpp and build it
+
+```sh
+git clone https://github.com/ggerganov/llama.cpp
+cd llama.cpp
+cmake .
+cmake --build . --config Release
+```
+
+Convert the model to ggml
+
+```sh
+python convert.py <path-to-vicuna-v0-model>
+```
+
+Quantize the model
+
+```sh
+./bin/quantize <vicuna-ggml-f16-model>.bin <vicuna-ggml-q4_1-model>.bin Q4_1
+```
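+
+## Running the example
+
+If the library was configured with `-DMINIGPT4_BUILD_EXAMPLES=ON`, an example binary built from `examples/main.cpp` is placed under `bin/`. A possible invocation, assuming the two quantized models sit in the working directory (model and image paths are illustrative; substitute your own):
+
+```sh
+./bin/main -m minigpt4-13B-f16.bin -lm ggml-vicuna-13b-v0-q4_1.bin --image llama.png --texts "what is the text in the picture?"
+```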
diff --git a/assets/minigpt4-demo1.gif b/assets/minigpt4-demo1.gif
new file mode 100644
index 0000000..5dd3d07
Binary files /dev/null and b/assets/minigpt4-demo1.gif differ
diff --git a/assets/webui_demo.png b/assets/webui_demo.png
new file mode 100644
index 0000000..d0fc1fb
Binary files /dev/null and b/assets/webui_demo.png differ
diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt
new file mode 100644
index 0000000..b820f4f
--- /dev/null
+++ b/examples/CMakeLists.txt
@@ -0,0 +1,9 @@
+add_dependency(argparse https://github.com/p-ranav/argparse 0b51382 TRUE)
+add_dependency(spdlog https://github.com/gabime/spdlog v1.11.0 TRUE)
+
+set(CMAKE_C_FLAGS_DEBUG "-g -DDEBUG")
+set(CMAKE_CXX_FLAGS_DEBUG "-g -DDEBUG")
+
+add_executable(main main.cpp)
+target_link_libraries(main PRIVATE minigpt4 ggml argparse spdlog)
+
diff --git a/examples/main.cpp b/examples/main.cpp
new file mode 100644
index 0000000..8ae4792
--- /dev/null
+++ b/examples/main.cpp
@@ -0,0 +1,302 @@
+// Standard and third-party headers inferred from usage below; the original
+// include list was lost in formatting.
+#include <filesystem>
+#include <iostream>
+#include <string>
+#include <thread>
+#include <vector>
+
+#include "minigpt4.h"
+
+#include <argparse/argparse.hpp>
+#include <spdlog/spdlog.h>
+#include <spdlog/stopwatch.h>
+
+#define INFO(...) spdlog::info(__VA_ARGS__)
+#define ERR(...)                \
+    spdlog::error(__VA_ARGS__); \
+    std::cerr << std::endl;
+#define ERR_EXIT(...) \
+    ERR(__VA_ARGS__); \
+    exit(-1);
+#define CHECK_ERR_EXIT(x, ...)                                      \
+    if (x)                                                          \
+    {                                                               \
+        ERR("ERROR MESSAGE: {}", minigpt4_error_code_to_string(x)); \
+        ERR_EXIT(__VA_ARGS__)                                       \
+    }
+
+namespace fs = std::filesystem;
+
+int main(int argc, char **argv)
+{
+    spdlog::set_pattern("[%H:%M:%S %z] [%n] [%^---%L---%$] [thread %t] %v");
+    spdlog::stopwatch sw;
+
+    argparse::ArgumentParser args("MiniGPT4.cpp", "1.0", argparse::default_arguments::help, false);
+
+    args.add_argument("-v", "--verbose")
+        .help("increase output verbosity")
+        .default_value(0)
+        .scan<'i', int>();
+
+    args.add_argument("-m", "--model")
+        .required()
+        .help("Path to the model file")
+        .default_value(std::string("minigpt4-13B-f16.bin"));
+
+    args.add_argument("-lm", "--llm_model")
+        .required()
+        .help("Path to language model")
+        .default_value(std::string("ggml-vicuna-13b-v0-q4_1.bin"));
+
+    args.add_argument("-t", "--threads")
+        .help("Number of threads to use")
+        .default_value(0)
+        .scan<'i', int>();
+
+    args.add_argument("--image")
+        .required()
+        .help("Images to encode")
+        .nargs(argparse::nargs_pattern::at_least_one)
+        .default_value(std::string{"../minigpt4/images/llama.png"});
+
+    args.add_argument("--texts")
+        .required()
+        .help("Texts to encode")
+        .nargs(argparse::nargs_pattern::at_least_one)
+        .default_value(std::vector<std::string>{"what is the text in the picture?", "what is the color of it?"});
+
+    args.add_argument("--temp")
+        .help("temperature")
+        .default_value(0.80f)
+        .scan<'f', float>();
+
+    args.add_argument("--top_k")
+        .help("top_k")
+        .default_value(40)
+        .scan<'i', int>();
+
+    args.add_argument("--top_p")
+        .help("top_p")
+        .default_value(0.90f)
+        .scan<'f', float>();
+
+    args.add_argument("--tfs_z")
+        .help("tfs_z")
+        .default_value(1.00f)
+        .scan<'f', float>();
+
+    args.add_argument("--typical_p")
+        .help("typical_p")
+        .default_value(1.00f)
+        .scan<'f', float>();
+
+    args.add_argument("--repeat_last_n")
+        .help("repeat_last_n")
+        .default_value(64)
+        .scan<'i', int>();
+
+    args.add_argument("--repeat_penalty")
+        .help("repeat_penalty")
+        .default_value(1.10f)
+        .scan<'f', float>();
+
+    args.add_argument("--alpha_presence")
+        .help("alpha_presence")
+        .default_value(1.00f)
+        .scan<'f', float>();
+
+    args.add_argument("--alpha_frequency")
+        .help("alpha_frequency")
+        .default_value(1.00f)
+        .scan<'f', float>();
+
+    args.add_argument("--mirostat")
+        .help("mirostat")
+        .default_value(0)
+        .scan<'i', int>();
+
+    args.add_argument("--mirostat_tau")
+        .help("mirostat_tau")
+        .default_value(5.00f)
+        .scan<'f', float>();
+
+    args.add_argument("--mirostat_eta")
+        .help("mirostat_eta")
+        .default_value(1.00f)
+        .scan<'f', float>();
+
+    args.add_argument("--penalize_nl")
+        .help("penalize_nl")
+        .default_value(1)
+        .scan<'i', int>();
+
+    args.add_argument("--n_ctx")
+        .help("n_ctx")
+        .default_value(2048)
+        .scan<'i', int>();
+
+    args.add_argument("--n_batch_size")
+        .help("n_batch_size")
+        .default_value(512)
+        .scan<'i', int>();
+
+    args.add_argument("--seed")
+        .help("seed")
+        .default_value(1337)
+        .scan<'i', int>();
+
+    args.add_argument("--numa")
+        .help("numa")
+        .default_value(0)
+        .scan<'i', int>();
+
+    args.parse_args(argc, argv);
+
+    auto model = args.get<std::string>("model");
+    auto llm_model = args.get<std::string>("llm_model");
+    auto verbose = args.get<int>("verbose");
+    auto threads = args.get<int>("threads");
+    auto texts = args.get<std::vector<std::string>>("texts");
+    auto image_path = args.get<std::string>("image");
+    auto temp = args.get<float>("temp");
+    auto top_k = args.get<int>("top_k");
+    auto top_p = args.get<float>("top_p");
+    auto tfs_z = args.get<float>("tfs_z");
+    auto typical_p = args.get<float>("typical_p");
+    auto repeat_last_n = args.get<int>("repeat_last_n");
+    auto repeat_penalty = args.get<float>("repeat_penalty");
+    auto alpha_presence = args.get<float>("alpha_presence");
+    auto alpha_frequency = args.get<float>("alpha_frequency");
+    auto mirostat = args.get<int>("mirostat");
+    auto mirostat_tau = args.get<float>("mirostat_tau");
+    auto mirostat_eta = args.get<float>("mirostat_eta");
+    auto penalize_nl = args.get<int>("penalize_nl");
+    auto seed = args.get<int>("seed");
+    auto n_ctx = args.get<int>("n_ctx");
+    auto n_batch_size = args.get<int>("n_batch_size");
+    auto numa = args.get<int>("numa");
+
+    if (threads <= 0)
+    {
+        threads = static_cast<int>(std::thread::hardware_concurrency());
+    }
+
+    INFO("=== Args ===");
+    INFO("Model: {}", model);
+    INFO("LLM Model: {}", llm_model);
+    INFO("Verbose: {}", verbose);
+    INFO("Threads: {}", threads);
+    INFO("Texts: {}", fmt::join(texts, ", "));
+    INFO("Images: {}", image_path);
+    INFO("============");
+    INFO("Running from {}", fs::current_path().string());
+
+    if (!fs::exists(model))
+    {
+        ERR("Model file '{}' does not exist", model);
+        return 1;
+    }
+
+    if (!fs::exists(llm_model))
+    {
+        ERR("LLM Model file '{}' does not exist", llm_model);
+        return 1;
+    }
+
+    if (!fs::exists(image_path))
+    {
+        ERR("Image file '{}' does not exist", image_path);
+        return 1;
+    }
+
+    auto ctx = minigpt4_model_load(model.c_str(), llm_model.c_str(), verbose, seed, n_ctx, n_batch_size, numa);
+    if (!ctx)
+    {
+        ERR("Failed to load model");
+        return 1;
+    }
+
+    MiniGPT4Image image{};
+    {
+        auto err = minigpt4_image_load_from_file(ctx, image_path.c_str(), &image, 0);
+        CHECK_ERR_EXIT(err, "Failed to load image for {}", image_path);
+    }
+
+    MiniGPT4Image preprocessed_image{};
+    {
+        auto err = minigpt4_preprocess_image(ctx, &image, &preprocessed_image, 0);
+        CHECK_ERR_EXIT(err, "Failed to preprocess image for {}", image_path);
+    }
+
+    MiniGPT4Embedding image_embedding{};
+    {
+        auto err = minigpt4_encode_image(ctx, &preprocessed_image, &image_embedding, threads);
+        CHECK_ERR_EXIT(err, "Failed to encode image for {}", image_path);
+    }
+
+    MiniGPT4Embeddings minigpt4_image_embeddings{
+        .embeddings = &image_embedding,
+        .n_embeddings = 1,
+    };
+
+    {
+        int err = minigpt4_system_prompt(ctx, threads);
+        CHECK_ERR_EXIT(err, "Failed to set system prompt");
+    }
+
+    {
+        const auto &text = texts[0];
+        int err = minigpt4_begin_chat_image(ctx, &image_embedding, text.c_str(), threads);
+        CHECK_ERR_EXIT(err, "Failed to chat image {}", image_path);
+        const char *token = nullptr;
+        std::string response;
+        response.reserve(2048);
+
+        do
+        {
+            if (token && !minigpt4_contains_eos_token(token))
+            {
+                std::cout << token << std::flush;
+            }
+            int err = minigpt4_end_chat_image(ctx, &token, threads, temp, top_k, top_p, tfs_z, typical_p, repeat_last_n, repeat_penalty, alpha_presence, alpha_frequency, mirostat, mirostat_tau, mirostat_eta, penalize_nl);
+            CHECK_ERR_EXIT(err, "Failed to generate chat image");
+            response += token;
+        } while (!minigpt4_is_eos(response.c_str()));
+    }
+
+    {
+        if (texts.size() > 1)
+        {
+            for (std::size_t i = 1; i < texts.size(); i++)
+            {
+                const auto &text = texts[i];
+                int err = minigpt4_begin_chat(ctx, text.c_str(), threads);
+                CHECK_ERR_EXIT(err, "Failed to begin chat");
+                const char *token = nullptr;
+                std::string response;
+                response.reserve(2048);
+
+                do
+                {
+                    if (token && !minigpt4_contains_eos_token(token))
+                    {
+                        std::cout << token << std::flush;
+                    }
+                    int err = minigpt4_end_chat(ctx, &token, threads, temp, top_k, top_p, tfs_z, typical_p, repeat_last_n, repeat_penalty, alpha_presence, alpha_frequency, mirostat, mirostat_tau, mirostat_eta, penalize_nl);
+                    CHECK_ERR_EXIT(err, "Failed to generate chat");
+                    response += token;
+                } while (!minigpt4_is_eos(response.c_str()));
+            }
+        }
+    }
+
+    const auto entire_time = sw.elapsed();
+
+    minigpt4_free_image(&image);
+    minigpt4_free_image(&preprocessed_image);
+    minigpt4_free_embedding(&image_embedding);
+    minigpt4_free(ctx);
+
+    if (verbose)
+    {
+        INFO("MiniGPT4");
+        INFO("Entire session time spent: {:10.2f} ms", entire_time.count() * 1000);
+    }
+
+    return 0;
+}
\ No newline at end of file
diff --git a/minigpt4.cpp b/minigpt4.cpp
new file mode 100644
index 0000000..f8287f1
--- /dev/null
+++ b/minigpt4.cpp
@@ -0,0 +1,2987 @@
+#include "minigpt4.h"
+
+// Standard-library headers inferred from usage in this file; the original
+// include list was lost in formatting.
+#include <algorithm>
+#include <chrono>
+#include <cmath>
+#include <cstdint>
+#include <cstdio>
+#include <cstring>
+#include <filesystem>
+#include <fstream>
+#include <iostream>
+#include <memory>
+#include <optional>
+#include <string>
+#include <string_view>
+#include <thread>
+#include <type_traits>
+#include <vector>
+
+#include "llama.h"
+#include "ggml.h"
+
+#include "fmt/core.h"
+#include "fmt/ranges.h"
+#include "ankerl/unordered_dense.h"
+
+#define STB_IMAGE_IMPLEMENTATION
+#include "stb_image.h"
+
+#include <nlohmann/json.hpp>
+using json = nlohmann::json;
+
+#include <tl/expected.hpp>
+
+#include <magic_enum.hpp>
+
+#ifdef MINIGPT4_BUILD_WITH_OPENCV
+    #include <opencv2/opencv.hpp>
+    #include <PillowResize.hpp>
+#endif
+
+/////////////////////
+/// PLATFORM INCLUDE
+/////////////////////
+
+#ifdef __has_include
+#if __has_include(<unistd.h>)
+#include <unistd.h>
+#if defined(_POSIX_MAPPED_FILES)
+#include <sys/mman.h>
+#endif
+#if defined(_POSIX_MEMLOCK_RANGE)
+#include <sys/resource.h>
+#endif
+#endif
+#endif
+
+#if defined(_WIN32)
+#define WIN32_LEAN_AND_MEAN
+#ifndef NOMINMAX
+#define NOMINMAX
+#endif
+#include <windows.h>
+#include <io.h>
+#include <stdio.h>
+#endif
+
+/////////////////////
+/// FORWARDS
+/////////////////////
+
+namespace fs = std::filesystem;
+using namespace std::chrono_literals;
+
+template <typename K, typename V>
+using HashMap = ankerl::unordered_dense::map<K, V>;
+
+constexpr auto PAGE_SIZE = 4096u;
+
+/////////////////////
+/// DEFINITIONS
+/////////////////////
+
+constexpr std::string_view EXPECTED_HEADER = "ggml";
+constexpr auto MB = 1024u * 1024u;
+constexpr auto GB = 1024u * MB;
+constexpr auto bytes_to_mb = [](auto bytes)
+{ return static_cast<double>(bytes) / MB; };
+
+enum MiniGPT4Error : int
+{
+    None,
+    LoadModelFileHeader,
+    LoadModelFileVersion,
+    LoadModelMiniGPT4DataType,
+    LoadLanguageModel,
+    OpenImage,
+    ImageSize,
+    MmapSupport,
+    FailedToAddString,
+    LLamaProjectionEmbeddingInvalidSize,
+    FailedToAddEmbedding,
+    EosToken,
+    Eos,
+    ImageNot224_244_3,
+    ImageNotF32,
+    ImageChannelsExpectedRGB,
+    ImageFormatExpectedU8,
+    PathDoesNotExist,
+    DumpModelFileOpen,
+    OpenCVNotLinked,
+};
+
+/////////////////////
+/// CONSTANT GLOBALS
+/////////////////////
+
+constexpr std::size_t PATCH_SIZE = 16;
+
+constexpr std::size_t NUM_ATTENTION_HEADS = 12;
+constexpr std::size_t ATTENTION_HEAD_SIZE = 64;
+constexpr std::size_t ALL_HEAD_SIZE = 768;
+
+constexpr std::size_t IMAGE_RESIZE = 224;
+
+constexpr std::size_t LLAMA_PROJECTION_EMBEDDING_SIZE1 = 32;
+constexpr std::size_t LLAMA_PROJECTION_HIDDEN_SIZE_7B = 4096;
+constexpr std::size_t LLAMA_PROJECTION_HIDDEN_SIZE_13B = 5120;
+constexpr std::size_t LLAMA_PROJECTION_EMBEDDING_SIZE_7B = LLAMA_PROJECTION_HIDDEN_SIZE_7B * LLAMA_PROJECTION_EMBEDDING_SIZE1;
+constexpr std::size_t LLAMA_PROJECTION_EMBEDDING_SIZE_13B = LLAMA_PROJECTION_HIDDEN_SIZE_13B * LLAMA_PROJECTION_EMBEDDING_SIZE1;
+
+constexpr std::string_view SYSTEM_PROMPT = R"(Give the following image: <Img>ImageContent</Img>. You will be able to see the image once I provide it to you. Please answer my questions.###)";
+constexpr std::string_view EOS_TOKEN_SUFFIX = "##";
+constexpr std::string_view EOS_SUFFIX = "###";
+
+constexpr float TORCH_FLOAT_FIFO_MIN = -3.40282e+38;
+
+constexpr std::size_t RGB_CHANNELS = 3;
+constexpr static std::size_t MAX_SCRATCH_BUFFERS = 1;
+
+/////////////////////
+/// MUTABLE GLOBALS
+/////////////////////
+
+static MiniGPT4Verbosity global_verbosity;
+
+/////////////////////
+/// Memory sizes
+/////////////////////
+
+enum class ModelType
+{
+    Unknown,
+    Vicuna7B,
+    Vicuna13B,
+};
+
+// TODO: dynamically determine sizes
+const static HashMap<ModelType, std::size_t> model_type_to_compute_size = {
+    {ModelType::Vicuna7B, 100 * MB},
+    {ModelType::Vicuna13B, 100 * MB},
+};
+
+const static HashMap<ModelType, std::size_t> model_type_to_scratch_size = {
+    {ModelType::Vicuna7B, 2814 * MB},
+    {ModelType::Vicuna13B, 2815 * MB},
+};
+
+/////////////////////
+/// UTILS
+/////////////////////
+
+#define CCAT(a, b) a##b
+#define CAT(a, b) CCAT(a, b)
+
+#define STRINGIFY2(x) #x
+#define STRINGIFY(x) STRINGIFY2(x)
+
+#define UNIQUIFY2(x) CAT(x, __LINE__)
+#define UNIQUIFY(x) UNIQUIFY2(x)
+
+#ifdef USE_PREFIX
+#define PREFIX "{}:{}:{} "
+#define PREFIX_ENTRIES __FILE__, __FUNCTION__, __LINE__
+#else
+#define PREFIX
+#define PREFIX_ENTRIES __FILE__
+#endif
+
+#define DEBUG(...)                                                                        \
+    do                                                                                    \
+    {                                                                                     \
+        if (global_verbosity >= MiniGPT4Verbosity::MINIGPT4_VERBOSITY_DEBUG)              \
+        {                                                                                 \
+            auto UNIQUIFY(log_header) = fmt::format(PREFIX "DEBUG: ", PREFIX_ENTRIES);    \
+            auto UNIQUIFY(other_info) = fmt::format(__VA_ARGS__);                         \
+            std::cout << UNIQUIFY(log_header) << UNIQUIFY(other_info) << "\n";            \
+        }                                                                                 \
+    } while (0)
+
+#define INFO(...)                                                                         \
+    do                                                                                    \
+    {                                                                                     \
+        if (global_verbosity >= MiniGPT4Verbosity::MINIGPT4_VERBOSITY_INFO)               \
+        {                                                                                 \
+            auto UNIQUIFY(log_header) = fmt::format(PREFIX "INFO: ", PREFIX_ENTRIES);     \
+            auto UNIQUIFY(other_info) = fmt::format(__VA_ARGS__);                         \
+            std::cout << UNIQUIFY(log_header) << UNIQUIFY(other_info) << "\n";            \
+        }                                                                                 \
+    } while (0)
+
+#define ERR(...)                                                                          \
+    do                                                                                    \
+    {                                                                                     \
+        if (global_verbosity >= MiniGPT4Verbosity::MINIGPT4_VERBOSITY_ERROR)              \
+        {                                                                                 \
+            auto UNIQUIFY(log_header) = fmt::format(PREFIX "ERROR: ", PREFIX_ENTRIES);    \
+            auto UNIQUIFY(other_info) = fmt::format(__VA_ARGS__);                         \
+            std::cerr << UNIQUIFY(log_header) << UNIQUIFY(other_info) << "\n";            \
+        }                                                                                 \
+    } while (0)
+
+#define PANIC(...)        \
+    ERR(__VA_ARGS__); \
+    exit(-1);
+
+#ifndef NDEBUG
+
+#define ASSERT(result, ...)                                                                                      \
+    do                                                                                                           \
+    {                                                                                                            \
+        if (!(result))                                                                                           \
+        {                                                                                                        \
+            auto UNIQUIFY(log_header) = fmt::format(PREFIX "ASSERT: [{}] ", PREFIX_ENTRIES, STRINGIFY(result));  \
+            auto UNIQUIFY(other_info) = fmt::format(__VA_ARGS__);                                                \
+            std::cerr << UNIQUIFY(log_header) << UNIQUIFY(other_info) << "\n";                                   \
+            exit(-1);                                                                                            \
+        }                                                                                                        \
+    } while (0)
+
+#else
+#define ASSERT(result, ...)
+#endif
+
+struct BufferView
+{
+    explicit BufferView(uint8_t *addr = nullptr, std::size_t size = 0) : addr(addr), size(size) {}
+
+    bool valid() const
+    {
+        return addr != nullptr && size != 0;
+    }
+
+    template <typename T>
+    T *As()
+    {
+        return reinterpret_cast<T *>(addr);
+    }
+
+    uint8_t *addr{};
+    std::size_t size{};
+};
+
+struct Buffer : public BufferView
+{
+    explicit Buffer() = default;
+    explicit Buffer(std::size_t size_)
+    {
+        size = size_;
+        if (size)
+        {
+            buf.resize(size);
+            addr = buf.data();
+        }
+    }
+
+    std::vector<uint8_t> buf{};
+};
+
+struct Timer
+{
+    explicit Timer() {}
+    double elapsed_us()
+    {
+        auto diff = std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::high_resolution_clock::now() - start).count();
+        return diff;
+    }
+
+    const std::chrono::time_point<std::chrono::high_resolution_clock> start = std::chrono::high_resolution_clock::now();
+};
+
+struct LoggingTimer : public Timer
+{
+    explicit LoggingTimer(std::string_view s_ = "") : s(std::string(s_)) {}
+    ~LoggingTimer()
+    {
+        auto diff = elapsed_us();
+        if (global_verbosity >= MiniGPT4Verbosity::MINIGPT4_VERBOSITY_INFO)
+        {
+            INFO("{} took {} us to complete", s, diff);
+        }
+    }
+
+    std::string s;
+};
+
+/////////////////////
+/// FILE UTILS
+/////////////////////
+
+class MMappedFile
+{
+public:
+    explicit MMappedFile() = default;
+#ifdef _POSIX_MAPPED_FILES
+    static constexpr bool SUPPORTED = true;
+    void load(fs::path p, bool prefetch = true)
+    {
+        fp = std::fopen(p.string().c_str(), "rb");
+        ASSERT(fp != nullptr, "file does not exist {}", p.string());
+        std::fseek(fp, 0, SEEK_END);
+        view.size = std::ftell(fp);
+        std::fseek(fp, 0, SEEK_SET);
+
+        int fd = fileno(fp);
+        int flags = MAP_SHARED;
+#ifdef __linux__
+        flags |= MAP_POPULATE;
+#endif
+        view.addr = reinterpret_cast<uint8_t *>(mmap(NULL, view.size, PROT_READ, flags, fd, 0));
+        if (view.addr == MAP_FAILED)
+        {
+            ERR("mmap failed: {}", strerror(errno));
+        }
+
+        if (prefetch)
+        {
+            // Advise the kernel to preload the mapped memory
+            if (madvise(view.addr, view.size, MADV_WILLNEED))
+            {
+                ERR("warning: madvise(.., MADV_WILLNEED) failed: {}\n",
+                    strerror(errno));
+            }
+        }
+    }
+
+    ~MMappedFile()
+    {
+        fclose(fp);
+        munmap(view.addr, view.size);
+    }
+#elif defined(_WIN32)
+    static constexpr bool SUPPORTED = true;
+
+    void load(fs::path p, bool prefetch = true)
+    {
+        fp = std::fopen(p.string().c_str(), "rb");
+        ASSERT(fp != nullptr, "file does not exist {}", p.string());
+        std::fseek(fp, 0, SEEK_END);
+        view.size = _ftelli64(fp);
+        std::fseek(fp, 0, SEEK_SET);
+
+        HANDLE hFile = (HANDLE)_get_osfhandle(_fileno(fp));
+
+        HANDLE hMapping = CreateFileMappingA(hFile, NULL, PAGE_READONLY, 0, 0, NULL);
+        DWORD error = GetLastError();
+
+        if (hMapping == NULL)
+        {
+            PANIC("CreateFileMappingA failed: {}", error);
+        }
+
+        view.addr = reinterpret_cast<uint8_t *>(MapViewOfFile(hMapping, FILE_MAP_READ, 0, 0, 0));
+        error = GetLastError();
+
+        if (view.addr == NULL)
+        {
+            PANIC("MapViewOfFile failed: {}", error);
+        }
+
+#if _WIN32_WINNT >= _WIN32_WINNT_WIN8
+        if (prefetch)
+        {
+            // Advise the kernel to preload the mapped memory
+            WIN32_MEMORY_RANGE_ENTRY range;
+            range.VirtualAddress = view.addr;
+            range.NumberOfBytes = (SIZE_T)view.size;
+            if (!PrefetchVirtualMemory(GetCurrentProcess(), 1, &range, 0))
+            {
+                INFO("PrefetchVirtualMemory failed: {}", GetLastError());
+            }
+        }
+#else
+#pragma message("warning: You are building for pre-Windows 8; prefetch not supported")
+#endif // _WIN32_WINNT >= _WIN32_WINNT_WIN8
+        CloseHandle(hMapping);
+    }
+
+    ~MMappedFile()
+    {
+        fclose(fp);
+        if (!UnmapViewOfFile(view.addr))
+        {
+            PANIC("UnmapViewOfFile failed: {}", GetLastError());
+        }
+    }
+#else
+    static constexpr bool SUPPORTED = false;
+
+    void load(fs::path p, bool prefetch = true)
+    {
+        PANIC("mmap not supported");
+    }
+#endif
+protected:
+    BufferView view;
+    FILE *fp{};
+};
+
+class MMapReader : public MMappedFile
+{
+public:
+    explicit MMapReader() = default;
+
+    template <typename T>
+    T base_addr()
+    {
+        return reinterpret_cast<T>(view.addr);
+    }
+
+    template <typename T>
+    T current_addr()
+    {
+        return reinterpret_cast<T>(view.addr + pos);
+    }
+
+    std::size_t tell()
+    {
+        return pos;
+    }
+
+    void seek(std::size_t new_pos)
+    {
+        pos = new_pos;
+        ASSERT(pos <= view.size, "Out of bounds for seeking {} > {}", pos, view.size);
+    }
+
+    void seek_to_alignment(std::size_t alignment)
+    {
+        if ((alignment - 1) & pos)
+        {
+            pos = (pos + alignment) & ~(alignment - 1);
+        }
+    }
+
+    bool is_eof() const
+    {
+        ASSERT(pos <= view.size, "Out of bounds for eof {} > {}", pos, view.size);
+        return pos == view.size;
+    }
+
+    void add_pos(std::size_t amount)
+    {
+        pos += amount;
+        ASSERT(pos <= view.size, "Out of bounds for reading {} > {}", pos, view.size);
+    }
+
+    template <typename T>
+    T &read_as()
+    {
+        T *t = current_addr<T *>();
+        add_pos(sizeof(T));
+        return *t;
+    }
+
+    int32_t read_s4()
+    {
+        return read_as<int32_t>();
+    }
+
+    std::string_view read_bytes(std::size_t len)
+    {
+        auto start = current_addr<const char *>();
+        std::string_view s(start, len);
+        add_pos(len);
+        return s;
+    }
+
+    std::string_view read_string()
+    {
+        auto string_length = read_s4();
+        auto s = read_bytes(string_length);
+        return s;
+    }
+
+    template <typename T>
+    void read_bytes_into(T buf, std::size_t len)
+    {
+        static_assert(std::is_pointer_v<T>, "T must be a pointer");
+        auto start = current_addr<const uint8_t *>();
+        std::copy(start, start + len, buf);
+        add_pos(len);
+    }
+
+private:
+    std::size_t pos{};
+};
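+
+// Usage sketch (illustrative only, not the actual loader further below): the
+// model file begins with a header and version followed by length-prefixed
+// fields, so a consumer of this reader looks roughly like
+//
+//   MMapReader reader;
+//   reader.load("minigpt4-7B-f16.bin");                      // path illustrative
+//   auto header = reader.read_bytes(EXPECTED_HEADER.size()); // expect "ggml"
+//   auto version = reader.read_s4();
+//   auto name = reader.read_string();                        // 4-byte length, then bytes
+//
+// Every read returns a view into the mapped file; nothing is copied.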
+
+/////////////////////
+/// Debug
+/////////////////////
+
+void WriteDump(ggml_tensor *t)
+{
+    std::ofstream f("out.txt", std::ios::trunc | std::ios::ate);
+    std::vector<std::size_t> sizes{(size_t *)&t->ne[0], (size_t *)&t->ne[4]};
+
+    auto total = sizes[0] * sizes[1] * sizes[2] * sizes[3];
+    for (std::size_t i = 0; i < total; i++)
+    {
+        auto *d = (float *)t->data;
+        auto dd = d[i];
+        f << fmt::format("{},", dd);
+    }
+    fmt::print("TOTAL {}\n", total);
+    f.close();
+    exit(-2);
+}
+
+#define DUMP_TENSOR(cur)                         \
+    {                                            \
+        auto xxx = cur;                          \
+        xxx = ggml_cont(ctx0, xxx);              \
+        ggml_set_name(xxx, "dump");              \
+        use_scratch(-1);                         \
+        struct ggml_cgraph gf = {};              \
+        gf.n_threads = 16;                       \
+        ggml_build_forward_expand(&gf, xxx);     \
+        ggml_graph_compute(ctx0, &gf);           \
+        auto *t = ggml_get_tensor(ctx0, "dump"); \
+        WriteDump(t);                            \
+    }
+
+/////////////////////
+/// Tensors
+/////////////////////
+
+tl::expected<ggml_type, MiniGPT4Error> data_type_to_ggml_type(MiniGPT4DataType data_type)
+{
+    ggml_type type;
+    switch (data_type)
+    {
+    case MiniGPT4DataType::F16:  type = GGML_TYPE_F16;  break;
+    case MiniGPT4DataType::F32:  type = GGML_TYPE_F32;  break;
+    case MiniGPT4DataType::I32:  type = GGML_TYPE_I32;  break;
+    case MiniGPT4DataType::Q4_0: type = GGML_TYPE_Q4_0; break;
+    case MiniGPT4DataType::Q4_1: type = GGML_TYPE_Q4_1; break;
+    case MiniGPT4DataType::Q5_0: type = GGML_TYPE_Q5_0; break;
+    case MiniGPT4DataType::Q5_1: type = GGML_TYPE_Q5_1; break;
+    case MiniGPT4DataType::Q8_0: type = GGML_TYPE_Q8_0; break;
+    case MiniGPT4DataType::Q8_1: type = GGML_TYPE_Q8_1; break;
+    case MiniGPT4DataType::Q2_K: type = GGML_TYPE_Q2_K; break;
+    case MiniGPT4DataType::Q3_K: type = GGML_TYPE_Q3_K; break;
+    case MiniGPT4DataType::Q4_K: type = GGML_TYPE_Q4_K; break;
+    case MiniGPT4DataType::Q5_K: type = GGML_TYPE_Q5_K; break;
+    case MiniGPT4DataType::Q6_K: type = GGML_TYPE_Q6_K; break;
+    case MiniGPT4DataType::Q8_K: type = GGML_TYPE_Q8_K; break;
+    case MiniGPT4DataType::L64:
+    default:
+    {
+        ERR("Unsupported MiniGPT4DataType {}", magic_enum::enum_name(data_type));
+        return tl::unexpected(MiniGPT4Error::LoadModelMiniGPT4DataType);
+    }
+    }
+    return type;
+}
+
+tl::expected<MiniGPT4DataType, MiniGPT4Error> ggml_type_to_data_type(ggml_type t)
+{
+    MiniGPT4DataType data_type;
+    switch (t)
+    {
+    case GGML_TYPE_F16:  data_type = MiniGPT4DataType::F16;  break;
+    case GGML_TYPE_F32:  data_type = MiniGPT4DataType::F32;  break;
+    case GGML_TYPE_I32:  data_type = MiniGPT4DataType::I32;  break;
+    case GGML_TYPE_Q4_0: data_type = MiniGPT4DataType::Q4_0; break;
+    case GGML_TYPE_Q4_1: data_type = MiniGPT4DataType::Q4_1; break;
+    case GGML_TYPE_Q5_0: data_type = MiniGPT4DataType::Q5_0; break;
+    case GGML_TYPE_Q5_1: data_type = MiniGPT4DataType::Q5_1; break;
+    case GGML_TYPE_Q8_0: data_type = MiniGPT4DataType::Q8_0; break;
+    case GGML_TYPE_Q8_1: data_type = MiniGPT4DataType::Q8_1; break;
+    case GGML_TYPE_Q2_K: data_type = MiniGPT4DataType::Q2_K; break;
+    case GGML_TYPE_Q3_K: data_type = MiniGPT4DataType::Q3_K; break;
+    case GGML_TYPE_Q4_K: data_type = MiniGPT4DataType::Q4_K; break;
+    case GGML_TYPE_Q5_K: data_type = MiniGPT4DataType::Q5_K; break;
+    case GGML_TYPE_Q6_K: data_type = MiniGPT4DataType::Q6_K; break;
+    case GGML_TYPE_Q8_K: data_type = MiniGPT4DataType::Q8_K; break;
+    default:
+    {
+        ERR("Unsupported ggml_type {}", magic_enum::enum_name(t));
+        return tl::unexpected(MiniGPT4Error::LoadModelMiniGPT4DataType);
+    }
+    }
+    return data_type;
+}
+
+struct LazyLoadTensor
+{
+    MMapReader *reader;
+    std::string name;
+    std::vector<int64_t> shape;
+    ggml_type type = ggml_type::GGML_TYPE_COUNT;
+
+    std::size_t pos = 0;
+
+    struct ggml_tensor *tensor = nullptr;
+    BufferView tensor_buf;
+
+    std::size_t type_size() const
+    {
+        switch (type)
+        {
+        case ggml_type::GGML_TYPE_F16:
+            return sizeof(float) / 2;
+        case ggml_type::GGML_TYPE_F32:
+            return sizeof(float);
+        case ggml_type::GGML_TYPE_I32:
+            return sizeof(int32_t);
+        default:
+            return ggml_type_size(type);
+        }
+        return 0;
+    }
+
+    std::size_t total_shape() const
+    {
+        std::size_t size = 1;
+        for (std::size_t i = 0; i < shape.size(); i++)
+        {
+            size *= shape[i];
+        }
+        return size;
+    }
+
+    std::size_t total_size() const
+    {
+        if (shape.empty())
+        {
+            return type_size();
+        }
+        std::size_t size = 1;
+        for (std::size_t i = 0; i < shape.size(); i++)
+        {
+            size *= shape[i];
+        }
+        size *= type_size();
+        return size;
+    }
+
+    auto get_size_in_bytes() const
+    {
+        // Calculate the size
+        struct ggml_tensor temp{};
+        temp.type = type;
+        auto k = 0;
+        for (; k < shape.size(); k++)
+        {
+            temp.ne[k] = shape[k];
+        }
+        for (; k < 4; k++)
+        {
+            temp.ne[k] = 1;
+        }
+        return ggml_nbytes(&temp);
+    }
+
+    auto get_file_address() const
+    {
+        return reader->base_addr<uint8_t *>() + pos;
+    }
+
+    struct ggml_tensor *operator()(ggml_context *ctx)
+    {
+        // Cached
+        if (tensor)
+        {
+            return tensor;
+        }
+
+        // Create tensors
+        const auto shape_size = shape.size();
+        if (shape_size == 1)
+        {
+            tensor = ggml_new_tensor_1d(ctx, type, shape[0]);
+        }
+        else if (shape_size == 2)
+        {
+            tensor = ggml_new_tensor_2d(ctx, type, shape[0], shape[1]);
+        }
+        else if (shape_size == 3)
+        {
+            tensor = ggml_new_tensor_3d(ctx, type, shape[0], shape[1], shape[2]);
+        }
+        else if (shape_size == 4)
+        {
+            tensor = ggml_new_tensor_4d(ctx, type, shape[0], shape[1], shape[2], shape[3]);
+        }
+        else
+        {
+            PANIC("Layer: {}, didn't expect shape of size {}", name, shape_size);
+        }
+
+        // Just reference it
+        tensor_buf.addr = get_file_address();
+        tensor_buf.size = get_size_in_bytes();
+
+        tensor->data = tensor_buf.addr;
+        return tensor;
+    }
+};
+
+class TorchModel
+{
+public:
+    void set_name(std::string_view s)
+    {
+        name = s;
+    }
+    const std::string &get_name() const
+    {
+        return name;
+    }
+
+    void add_tensor(std::string_view name, LazyLoadTensor tensor)
+    {
+        tensors.try_emplace(std::string(name), tensor);
+    }
+
+    template <typename... Args>
+    LazyLoadTensor &get(Args &&...args)
+    {
+        const auto tensor_name = fmt::format(std::forward<Args>(args)...);
+        return operator[](tensor_name);
+    }
+
+    std::optional<LazyLoadTensor *> get_tensor(const std::string &tensor_name)
+    {
+        if (auto found = tensors.find(tensor_name); found != std::end(tensors))
+        {
+            auto &[_, tensor] = *found;
+            return &tensor;
+        }
+        return std::nullopt;
+    }
+
+    LazyLoadTensor &operator[](const std::string &tensor_name)
+    {
+        if (auto tensor = get_tensor(tensor_name))
+        {
+            return **tensor;
+        }
+        PANIC("Couldn't find tensor {}", tensor_name);
+        return tensors.begin()->second;
+    }
+
+    const LazyLoadTensor &operator[](const std::string &tensor_name) const
+    {
+        return const_cast<TorchModel *>(this)->operator[](tensor_name);
+    }
+
+    auto &get_tensors() { return tensors; }
+    const auto &get_tensors() const { return tensors; }
+
+private:
+    std::string name;
+    HashMap<std::string, LazyLoadTensor> tensors;
+};
+
+struct ContextBuffer
+{
+    void init_context(std::size_t buf_compute_size,
+                      std::size_t buf_scratch_size,
+                      std::size_t num_scratch_buffers = MAX_SCRATCH_BUFFERS)
+    {
+        buf_scratch.resize(num_scratch_buffers);
+        buf_max_size.resize(num_scratch_buffers);
+        reset_scratch_usage();
+
+        buf_compute = Buffer(buf_compute_size);
+        if (buf_scratch_size)
+        {
+            for (std::size_t i = 0; i < num_scratch_buffers; i++)
+            {
+                buf_scratch[i] = Buffer(buf_scratch_size);
+            }
+        }
+    }
+
+    void use_scratch(int i)
+    {
+        size_t last_size = 0;
+
+        if (i == -1)
+        {
+            last_size = ggml_set_scratch(ctx, {0, 0, nullptr});
+        }
+        else
+        {
+            auto &buf = buf_scratch[i];
+            last_size = ggml_set_scratch(ctx, {0, buf.size, buf.addr});
+        }
+
+        if (buf_last >= 0)
+        {
+            buf_max_size[buf_last] = std::max(buf_max_size[buf_last], last_size);
+        }
+
+        buf_last = i;
+    }
+
+    auto get_memory_usage(int i)
+    {
+        if (i == -1)
+        {
+            return ggml_used_mem(ctx);
+        }
+        return buf_max_size[static_cast<std::size_t>(i)];
+    }
+
+    void reset_scratch_usage()
+    {
+        buf_last = 0;
+        for (auto &s : buf_max_size)
+        {
+            s = 0;
+        }
+    }
+
+    Buffer buf_compute;
+    std::vector<Buffer> buf_scratch;
+    int buf_last = 0;
+    std::vector<std::size_t> buf_max_size;
+
+    ggml_context *ctx{};
+};
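+
+// Usage sketch (hypothetical tensor name; real names come from the converted
+// checkpoint): weights are looked up by string and only become ggml tensors,
+// backed directly by the mmapped file, the first time they are used:
+//
+//   LazyLoadTensor &lazy = model["llama_proj.weight"];  // name illustrative
+//   ggml_tensor *w = lazy(ctx0);                        // created once, cached afterwards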
+
+template <typename T>
+struct HasContext
+{
+    ggml_context *data_ctx = nullptr;
+
+    template <typename... Args>
+    auto operator()(ggml_context *ctx, ggml_tensor *x, Args &&...args)
+    {
+        return static_cast<T *>(this)->forward(ctx, x, std::forward<Args>(args)...);
+    }
+};
+
+struct HasContextBase;
+
+template