From 65f2c1bedc35439714536b0be7723fac8673d67a Mon Sep 17 00:00:00 2001 From: Todd Fisher Date: Mon, 13 Feb 2023 11:18:24 -0500 Subject: [PATCH 1/8] adding ruby bindings --- bindings/ruby/ext/extconf.rb | 20 + bindings/ruby/ext/ggml.c | 8616 +++++++++++++++++++++++++++ bindings/ruby/ext/ggml.h | 748 +++ bindings/ruby/ext/ruby_whisper.c | 92 + bindings/ruby/ext/ruby_whisper.h | 14 + bindings/ruby/ext/whisper.cpp | 4814 +++++++++++++++ bindings/ruby/ext/whisper.h | 379 ++ bindings/ruby/tests/test_whisper.rb | 23 + 8 files changed, 14706 insertions(+) create mode 100644 bindings/ruby/ext/extconf.rb create mode 100644 bindings/ruby/ext/ggml.c create mode 100644 bindings/ruby/ext/ggml.h create mode 100644 bindings/ruby/ext/ruby_whisper.c create mode 100644 bindings/ruby/ext/ruby_whisper.h create mode 100644 bindings/ruby/ext/whisper.cpp create mode 100644 bindings/ruby/ext/whisper.h create mode 100644 bindings/ruby/tests/test_whisper.rb diff --git a/bindings/ruby/ext/extconf.rb b/bindings/ruby/ext/extconf.rb new file mode 100644 index 00000000000..0421cb2bb98 --- /dev/null +++ b/bindings/ruby/ext/extconf.rb @@ -0,0 +1,20 @@ +require 'mkmf' +system("cp #{File.join(File.dirname(__FILE__),'..','..','..','whisper.cpp')} .") +system("cp #{File.join(File.dirname(__FILE__),'..','..','..','whisper.h')} .") +system("cp #{File.join(File.dirname(__FILE__),'..','..','..','ggml.h')} .") +system("cp #{File.join(File.dirname(__FILE__),'..','..','..','ggml.c')} .") + + +# need to use c++ compiler flags +$CXXFLAGS << ' -std=c++11' +# Set to true when building binary gems +if enable_config('static-stdlib', false) + $LDFLAGS << ' -static-libgcc -static-libstdc++' +end + +if enable_config('march-tune-native', false) + $CFLAGS << ' -march=native -mtune=native' + $CXXFLAGS << ' -march=native -mtune=native' +end + +create_makefile('whisper') diff --git a/bindings/ruby/ext/ggml.c b/bindings/ruby/ext/ggml.c new file mode 100644 index 00000000000..d67612c36a3 --- /dev/null +++ b/bindings/ruby/ext/ggml.c @@ -0,0 +1,8616 @@ +#include "ggml.h" + +#if defined(_MSC_VER) || defined(__MINGW32__) +#include // using malloc.h with MSC/MINGW +#elif !defined(__FreeBSD__) +#include +#endif + +#include +#include +#include +#include +#include +#include +#include + +// if C99 - static_assert is noop +// ref: https://stackoverflow.com/a/53923785/4039976 +#ifndef static_assert +#define static_assert(cond, msg) struct global_scope_noop_trick +#endif + +#if defined _MSC_VER || defined(__MINGW32__) + +#if !defined(__MINGW32__) +#include +#else +// ref: https://github.com/ggerganov/whisper.cpp/issues/168 +#include +#include +#endif + +typedef volatile LONG atomic_int; +typedef atomic_int atomic_bool; + +static void atomic_store(atomic_int* ptr, LONG val) { + InterlockedExchange(ptr, val); +} +static LONG atomic_load(atomic_int* ptr) { + return InterlockedCompareExchange(ptr, 0, 0); +} +static LONG atomic_fetch_add(atomic_int* ptr, LONG inc) { + return InterlockedExchangeAdd(ptr, inc); +} +static LONG atomic_fetch_sub(atomic_int* ptr, LONG dec) { + return atomic_fetch_add(ptr, -(dec)); +} + +typedef HANDLE pthread_t; + +typedef DWORD thread_ret_t; +static int pthread_create(pthread_t* out, void* unused, thread_ret_t(*func)(void*), void* arg) { + HANDLE handle = CreateThread(NULL, 0, (LPTHREAD_START_ROUTINE) func, arg, 0, NULL); + if (handle == NULL) + { + return EAGAIN; + } + + *out = handle; + return 0; +} + +static int pthread_join(pthread_t thread, void* unused) { + return (int) WaitForSingleObject(thread, INFINITE); +} + +static 
int sched_yield (void) { + Sleep (0); + return 0; +} +#else +#include +#include + +typedef void* thread_ret_t; +#endif + +#ifdef __HAIKU__ +#define static_assert(cond, msg) _Static_assert(cond, msg) +#endif + +/*#define GGML_PERF*/ +#define GGML_DEBUG 0 +#define GGML_GELU_FP16 + +#define GGML_SOFT_MAX_UNROLL 4 +#define GGML_VEC_DOT_UNROLL 2 + +#ifdef GGML_USE_ACCELERATE +// uncomment to use vDSP for soft max computation +// note: not sure if it is actually faster +//#define GGML_SOFT_MAX_ACCELERATE +#endif + +#if UINTPTR_MAX == 0xFFFFFFFF + #define GGML_MEM_ALIGN 4 +#else + #define GGML_MEM_ALIGN 16 +#endif + +#define UNUSED(x) (void)(x) +#define SWAP(x, y, T) do { T SWAP = x; x = y; y = SWAP; } while (0) + +#define GGML_ASSERT(x) \ + do { \ + if (!(x)) { \ + fprintf(stderr, "GGML_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \ + abort(); \ + } \ + } while (0) + +#ifdef GGML_USE_ACCELERATE +#include +#elif GGML_USE_OPENBLAS +#include +#endif + +#undef MIN +#undef MAX +#define MIN(a, b) ((a) < (b) ? (a) : (b)) +#define MAX(a, b) ((a) > (b) ? (a) : (b)) + +// floating point type used to accumulate sums +typedef double ggml_float; + +// 16-bit float +// on Arm, we use __fp16 +// on x86, we use uint16_t +#ifdef __ARM_NEON + +// if YCM cannot find , make a symbolic link to it, for example: +// +// $ ln -sfn /Library/Developer/CommandLineTools/usr/lib/clang/13.1.6/include/arm_neon.h ./src/ +// +#include + +#define GGML_COMPUTE_FP16_TO_FP32(x) (x) +#define GGML_COMPUTE_FP32_TO_FP16(x) (x) + +#define GGML_FP16_TO_FP32(x) (x) +#define GGML_FP32_TO_FP16(x) (x) + +#else + +#ifdef __wasm_simd128__ +#include +#else +#ifdef __POWER9_VECTOR__ +#include +#undef bool +#define bool _Bool +#else +#include +#endif +#endif + +#ifdef __F16C__ + +#define GGML_COMPUTE_FP16_TO_FP32(x) _cvtsh_ss(x) +#define GGML_COMPUTE_FP32_TO_FP16(x) _cvtss_sh(x, 0) + +#else + +// FP16 <-> FP32 +// ref: https://github.com/Maratyszcza/FP16 + +static inline float fp32_from_bits(uint32_t w) { + union { + uint32_t as_bits; + float as_value; + } fp32; + fp32.as_bits = w; + return fp32.as_value; +} + +static inline uint32_t fp32_to_bits(float f) { + union { + float as_value; + uint32_t as_bits; + } fp32; + fp32.as_value = f; + return fp32.as_bits; +} + +static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) { + const uint32_t w = (uint32_t) h << 16; + const uint32_t sign = w & UINT32_C(0x80000000); + const uint32_t two_w = w + w; + + const uint32_t exp_offset = UINT32_C(0xE0) << 23; +#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__) + const float exp_scale = 0x1.0p-112f; +#else + const float exp_scale = fp32_from_bits(UINT32_C(0x7800000)); +#endif + const float normalized_value = fp32_from_bits((two_w >> 4) + exp_offset) * exp_scale; + + const uint32_t magic_mask = UINT32_C(126) << 23; + const float magic_bias = 0.5f; + const float denormalized_value = fp32_from_bits((two_w >> 17) | magic_mask) - magic_bias; + + const uint32_t denormalized_cutoff = UINT32_C(1) << 27; + const uint32_t result = sign | + (two_w < denormalized_cutoff ? 
fp32_to_bits(denormalized_value) : fp32_to_bits(normalized_value)); + return fp32_from_bits(result); +} + +static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) { +#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__) + const float scale_to_inf = 0x1.0p+112f; + const float scale_to_zero = 0x1.0p-110f; +#else + const float scale_to_inf = fp32_from_bits(UINT32_C(0x77800000)); + const float scale_to_zero = fp32_from_bits(UINT32_C(0x08800000)); +#endif + float base = (fabsf(f) * scale_to_inf) * scale_to_zero; + + const uint32_t w = fp32_to_bits(f); + const uint32_t shl1_w = w + w; + const uint32_t sign = w & UINT32_C(0x80000000); + uint32_t bias = shl1_w & UINT32_C(0xFF000000); + if (bias < UINT32_C(0x71000000)) { + bias = UINT32_C(0x71000000); + } + + base = fp32_from_bits((bias >> 1) + UINT32_C(0x07800000)) + base; + const uint32_t bits = fp32_to_bits(base); + const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00); + const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF); + const uint32_t nonsign = exp_bits + mantissa_bits; + return (sign >> 16) | (shl1_w > UINT32_C(0xFF000000) ? UINT16_C(0x7E00) : nonsign); +} + +#define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x) +#define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x) + +#endif // __F16C__ + +#endif // __ARM_NEON + +// +// global data +// + +// precomputed gelu table for f16 (128 KB) +static ggml_fp16_t table_gelu_f16[1 << 16]; + +// precomputed exp table for f16 (128 KB) +static ggml_fp16_t table_exp_f16[1 << 16]; + +// precomputed f32 table for f16 (256 KB) +static float table_f32_f16[1 << 16]; + +// On ARM NEON, it's quicker to directly convert x -> x instead of calling into ggml_lookup_fp16_to_fp32, +// so we define GGML_FP16_TO_FP32 and GGML_FP32_TO_FP16 elsewhere for NEON. 
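+// On other targets, GGML_FP16_TO_FP32 falls back to ggml_lookup_fp16_to_fp32 below, which
+// indexes the 64K-entry table_f32_f16 by the raw 16-bit pattern of the half-precision value
+// (the table itself is filled once in ggml_init).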
+#if !defined(GGML_FP16_TO_FP32) || !defined(GGML_FP32_TO_FP16) + +inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) { + uint16_t s; + memcpy(&s, &f, sizeof(uint16_t)); + return table_f32_f16[s]; +} + +#define GGML_FP16_TO_FP32(x) ggml_lookup_fp16_to_fp32(x) +#define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x) + +#endif + +// note: do not use these inside ggml.c +// these are meant to be used via the ggml.h API +float ggml_fp16_to_fp32(ggml_fp16_t x) { + return GGML_FP16_TO_FP32(x); +} + +ggml_fp16_t ggml_fp32_to_fp16(float x) { + return GGML_FP32_TO_FP16(x); +} + +// +// timing +// + +#if defined(_MSC_VER) || defined(__MINGW32__) +static int64_t timer_freq; +void ggml_time_init(void) { + LARGE_INTEGER frequency; + QueryPerformanceFrequency(&frequency); + timer_freq = frequency.QuadPart; +} +int64_t ggml_time_ms(void) { + LARGE_INTEGER t; + QueryPerformanceCounter(&t); + return (t.QuadPart * 1000) / timer_freq; +} +int64_t ggml_time_us(void) { + LARGE_INTEGER t; + QueryPerformanceCounter(&t); + return (t.QuadPart * 1000000) / timer_freq; +} +#else +void ggml_time_init(void) {} +int64_t ggml_time_ms(void) { + struct timespec ts; + clock_gettime(CLOCK_MONOTONIC, &ts); + return (int64_t)ts.tv_sec*1000 + (int64_t)ts.tv_nsec/1000000; +} + +int64_t ggml_time_us(void) { + struct timespec ts; + clock_gettime(CLOCK_MONOTONIC, &ts); + return (int64_t)ts.tv_sec*1000000 + (int64_t)ts.tv_nsec/1000; +} +#endif + +int64_t ggml_cycles(void) { + return clock(); +} + +int64_t ggml_cycles_per_ms(void) { + return CLOCKS_PER_SEC/1000; +} + +#ifdef GGML_PERF +#define ggml_perf_time_ms() ggml_time_ms() +#define ggml_perf_time_us() ggml_time_us() +#define ggml_perf_cycles() ggml_cycles() +#define ggml_perf_cycles_per_ms() ggml_cycles_per_ms() +#else +#define ggml_perf_time_ms() 0 +#define ggml_perf_time_us() 0 +#define ggml_perf_cycles() 0 +#define ggml_perf_cycles_per_ms() 0 +#endif + +// +// cache line +// + +#if defined(__cpp_lib_hardware_interference_size) +#define CACHE_LINE_SIZE hardware_destructive_interference_size +#else +#if defined(__POWER9_VECTOR__) +#define CACHE_LINE_SIZE 128 +#else +#define CACHE_LINE_SIZE 64 +#endif +#endif + +static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float); + +// +// simd mappings +// + +// we define a common set of C macros which map to specific intrinsics based on the current architecture +// we then implement the fundamental computation operations below using only these macros +// adding support for new architectures requires to define the corresponding SIMD macros +// +// GGML_F32_STEP / GGML_F16_STEP +// number of elements to process in a single step +// +// GGML_F32_EPR / GGML_F16_EPR +// number of elements to fit in a single register +// + +#if defined(__ARM_NEON) && defined(__ARM_FEATURE_FMA) + +#define GGML_SIMD + +// F32 NEON + +#define GGML_F32_STEP 16 +#define GGML_F32_EPR 4 + +#define GGML_F32x4 float32x4_t +#define GGML_F32x4_ZERO vdupq_n_f32(0.0f) +#define GGML_F32x4_SET1(x) vdupq_n_f32(x) +#define GGML_F32x4_LOAD vld1q_f32 +#define GGML_F32x4_STORE vst1q_f32 +#define GGML_F32x4_FMA(a, b, c) vfmaq_f32(a, b, c) +#define GGML_F32x4_ADD vaddq_f32 +#define GGML_F32x4_MUL vmulq_f32 +#if defined(__ARM_FEATURE_QRDMX) + #define GGML_F32x4_REDUCE_ONE(x) vaddvq_f32(x) +#else + #define GGML_F32x4_REDUCE_ONE(x) \ + (vgetq_lane_f32(x, 0) + \ + vgetq_lane_f32(x, 1) + \ + vgetq_lane_f32(x, 2) + \ + vgetq_lane_f32(x, 3)) +#endif +#define GGML_F32x4_REDUCE(res, x) \ +{ \ + for (int i = 0; i < GGML_F32_ARR/2; ++i) { \ + x[2*i] = 
vaddq_f32(x[2*i], x[2*i+1]); \ + } \ + for (int i = 0; i < GGML_F32_ARR/4; ++i) { \ + x[4*i] = vaddq_f32(x[4*i], x[4*i+2]); \ + } \ + for (int i = 0; i < GGML_F32_ARR/8; ++i) { \ + x[8*i] = vaddq_f32(x[8*i], x[8*i+4]); \ + } \ + res = GGML_F32x4_REDUCE_ONE(x[0]); \ +} + +#define GGML_F32_VEC GGML_F32x4 +#define GGML_F32_VEC_ZERO GGML_F32x4_ZERO +#define GGML_F32_VEC_SET1 GGML_F32x4_SET1 +#define GGML_F32_VEC_LOAD GGML_F32x4_LOAD +#define GGML_F32_VEC_STORE GGML_F32x4_STORE +#define GGML_F32_VEC_FMA GGML_F32x4_FMA +#define GGML_F32_VEC_ADD GGML_F32x4_ADD +#define GGML_F32_VEC_MUL GGML_F32x4_MUL +#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE + +// F16 NEON + +#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) + #define GGML_F16_STEP 32 + #define GGML_F16_EPR 8 + + #define GGML_F16x8 float16x8_t + #define GGML_F16x8_ZERO vdupq_n_f16(0.0f) + #define GGML_F16x8_SET1(x) vdupq_n_f16(x) + #define GGML_F16x8_LOAD vld1q_f16 + #define GGML_F16x8_STORE vst1q_f16 + #define GGML_F16x8_FMA(a, b, c) vfmaq_f16(a, b, c) + #define GGML_F16x8_ADD vaddq_f16 + #define GGML_F16x8_MUL vmulq_f16 + #define GGML_F16x8_REDUCE(res, x) \ + { \ + for (int i = 0; i < GGML_F16_ARR/2; ++i) { \ + x[2*i] = vaddq_f16(x[2*i], x[2*i+1]); \ + } \ + for (int i = 0; i < GGML_F16_ARR/4; ++i) { \ + x[4*i] = vaddq_f16(x[4*i], x[4*i+2]); \ + } \ + for (int i = 0; i < GGML_F16_ARR/8; ++i) { \ + x[8*i] = vaddq_f16(x[8*i], x[8*i+4]); \ + } \ + const float32x4_t t0 = vcvt_f32_f16(vget_low_f16 (x[0])); \ + const float32x4_t t1 = vcvt_f32_f16(vget_high_f16(x[0])); \ + res = vaddvq_f32(vaddq_f32(t0, t1)); \ + } + + #define GGML_F16_VEC GGML_F16x8 + #define GGML_F16_VEC_ZERO GGML_F16x8_ZERO + #define GGML_F16_VEC_SET1 GGML_F16x8_SET1 + #define GGML_F16_VEC_LOAD(p, i) GGML_F16x8_LOAD(p) + #define GGML_F16_VEC_STORE(p, r, i) GGML_F16x8_STORE(p, r[i]) + #define GGML_F16_VEC_FMA GGML_F16x8_FMA + #define GGML_F16_VEC_ADD GGML_F16x8_ADD + #define GGML_F16_VEC_MUL GGML_F16x8_MUL + #define GGML_F16_VEC_REDUCE GGML_F16x8_REDUCE +#else + // if FP16 vector arithmetic is not supported, we use FP32 instead + // and take advantage of the vcvt_ functions to convert to/from FP16 + + #define GGML_F16_STEP 16 + #define GGML_F16_EPR 4 + + #define GGML_F32Cx4 float32x4_t + #define GGML_F32Cx4_ZERO vdupq_n_f32(0.0f) + #define GGML_F32Cx4_SET1(x) vdupq_n_f32(x) + #define GGML_F32Cx4_LOAD(x) vcvt_f32_f16(vld1_f16(x)) + #define GGML_F32Cx4_STORE(x, y) vst1_f16(x, vcvt_f16_f32(y)) + #define GGML_F32Cx4_FMA(a, b, c) vfmaq_f32(a, b, c) + #define GGML_F32Cx4_ADD vaddq_f32 + #define GGML_F32Cx4_MUL vmulq_f32 + #define GGML_F32Cx4_REDUCE GGML_F32x4_REDUCE + + #define GGML_F16_VEC GGML_F32Cx4 + #define GGML_F16_VEC_ZERO GGML_F32Cx4_ZERO + #define GGML_F16_VEC_SET1 GGML_F32Cx4_SET1 + #define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx4_LOAD(p) + #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE(p, r[i]) + #define GGML_F16_VEC_FMA GGML_F32Cx4_FMA + #define GGML_F16_VEC_ADD GGML_F32Cx4_ADD + #define GGML_F16_VEC_MUL GGML_F32Cx4_MUL + #define GGML_F16_VEC_REDUCE GGML_F32Cx4_REDUCE +#endif + +#elif defined(__AVX__) + +#define GGML_SIMD + +// F32 AVX + +#define GGML_F32_STEP 32 +#define GGML_F32_EPR 8 + +#define GGML_F32x8 __m256 +#define GGML_F32x8_ZERO _mm256_setzero_ps() +#define GGML_F32x8_SET1(x) _mm256_set1_ps(x) +#define GGML_F32x8_LOAD _mm256_loadu_ps +#define GGML_F32x8_STORE _mm256_storeu_ps +#if defined(__FMA__) + #define GGML_F32x8_FMA(a, b, c) _mm256_fmadd_ps(b, c, a) +#else + #define GGML_F32x8_FMA(a, b, c) _mm256_add_ps(_mm256_mul_ps(b, c), a) +#endif +#define 
GGML_F32x8_ADD _mm256_add_ps +#define GGML_F32x8_MUL _mm256_mul_ps +#define GGML_F32x8_REDUCE(res, x) \ +{ \ + for (int i = 0; i < GGML_F32_ARR/2; ++i) { \ + x[2*i] = _mm256_add_ps(x[2*i], x[2*i+1]); \ + } \ + for (int i = 0; i < GGML_F32_ARR/4; ++i) { \ + x[4*i] = _mm256_add_ps(x[4*i], x[4*i+2]); \ + } \ + for (int i = 0; i < GGML_F32_ARR/8; ++i) { \ + x[8*i] = _mm256_add_ps(x[8*i], x[8*i+4]); \ + } \ + const __m128 t0 = _mm_add_ps(_mm256_castps256_ps128(x[0]), \ + _mm256_extractf128_ps(x[0], 1)); \ + const __m128 t1 = _mm_hadd_ps(t0, t0); \ + res = _mm_cvtss_f32(_mm_hadd_ps(t1, t1)); \ +} +// TODO: is this optimal ? + +#define GGML_F32_VEC GGML_F32x8 +#define GGML_F32_VEC_ZERO GGML_F32x8_ZERO +#define GGML_F32_VEC_SET1 GGML_F32x8_SET1 +#define GGML_F32_VEC_LOAD GGML_F32x8_LOAD +#define GGML_F32_VEC_STORE GGML_F32x8_STORE +#define GGML_F32_VEC_FMA GGML_F32x8_FMA +#define GGML_F32_VEC_ADD GGML_F32x8_ADD +#define GGML_F32_VEC_MUL GGML_F32x8_MUL +#define GGML_F32_VEC_REDUCE GGML_F32x8_REDUCE + +// F16 AVX + +#define GGML_F16_STEP 32 +#define GGML_F16_EPR 8 + +// F16 arithmetic is not supported by AVX, so we use F32 instead +// we take advantage of the _mm256_cvt intrinsics to convert F16 <-> F32 + +#define GGML_F32Cx8 __m256 +#define GGML_F32Cx8_ZERO _mm256_setzero_ps() +#define GGML_F32Cx8_SET1(x) _mm256_set1_ps(x) +#define GGML_F32Cx8_LOAD(x) _mm256_cvtph_ps(_mm_loadu_si128((__m128i *)(x))) +#define GGML_F32Cx8_STORE(x, y) _mm_storeu_si128((__m128i *)(x), _mm256_cvtps_ph(y, 0)) +#define GGML_F32Cx8_FMA GGML_F32x8_FMA +#define GGML_F32Cx8_ADD _mm256_add_ps +#define GGML_F32Cx8_MUL _mm256_mul_ps +#define GGML_F32Cx8_REDUCE GGML_F32x8_REDUCE + +#define GGML_F16_VEC GGML_F32Cx8 +#define GGML_F16_VEC_ZERO GGML_F32Cx8_ZERO +#define GGML_F16_VEC_SET1 GGML_F32Cx8_SET1 +#define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx8_LOAD(p) +#define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx8_STORE(p, r[i]) +#define GGML_F16_VEC_FMA GGML_F32Cx8_FMA +#define GGML_F16_VEC_ADD GGML_F32Cx8_ADD +#define GGML_F16_VEC_MUL GGML_F32Cx8_MUL +#define GGML_F16_VEC_REDUCE GGML_F32Cx8_REDUCE + +#elif defined(__POWER9_VECTOR__) + +#define GGML_SIMD + +// F32 POWER9 + +#define GGML_F32_STEP 32 +#define GGML_F32_EPR 4 + +#define GGML_F32x4 vector float +#define GGML_F32x4_ZERO 0.0f +#define GGML_F32x4_SET1 vec_splats +#define GGML_F32x4_LOAD(p) vec_xl(0, p) +#define GGML_F32x4_STORE(p, r) vec_xst(r, 0, p) +#define GGML_F32x4_FMA(a, b, c) vec_madd(b, c, a) +#define GGML_F32x4_ADD vec_add +#define GGML_F32x4_MUL vec_mul +#define GGML_F32x4_REDUCE(res, x) \ +{ \ + for (int i = 0; i < GGML_F32_ARR/2; ++i) { \ + x[2*i] = vec_add(x[2*i], x[2*i+1]); \ + } \ + for (int i = 0; i < GGML_F32_ARR/4; ++i) { \ + x[4*i] = vec_add(x[4*i], x[4*i+2]); \ + } \ + for (int i = 0; i < GGML_F32_ARR/8; ++i) { \ + x[8*i] = vec_add(x[8*i], x[8*i+4]); \ + } \ + res = vec_extract(x[0], 0) + \ + vec_extract(x[0], 1) + \ + vec_extract(x[0], 2) + \ + vec_extract(x[0], 3); \ +} + +#define GGML_F32_VEC GGML_F32x4 +#define GGML_F32_VEC_ZERO GGML_F32x4_ZERO +#define GGML_F32_VEC_SET1 GGML_F32x4_SET1 +#define GGML_F32_VEC_LOAD GGML_F32x4_LOAD +#define GGML_F32_VEC_STORE GGML_F32x4_STORE +#define GGML_F32_VEC_FMA GGML_F32x4_FMA +#define GGML_F32_VEC_ADD GGML_F32x4_ADD +#define GGML_F32_VEC_MUL GGML_F32x4_MUL +#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE + +// F16 POWER9 +#define GGML_F16_STEP GGML_F32_STEP +#define GGML_F16_EPR GGML_F32_EPR +#define GGML_F16_VEC GGML_F32x4 +#define GGML_F16_VEC_ZERO GGML_F32x4_ZERO +#define GGML_F16_VEC_SET1 GGML_F32x4_SET1 +#define 
GGML_F16_VEC_FMA GGML_F32x4_FMA +#define GGML_F16_VEC_REDUCE GGML_F32x4_REDUCE +// Use vec_xl, not vec_ld, in case the load address is not aligned. +#define GGML_F16_VEC_LOAD(p, i) (i & 0x1) ? \ + vec_extract_fp32_from_shorth(vec_xl(0, p - GGML_F16_EPR)) : \ + vec_extract_fp32_from_shortl(vec_xl(0, p)) +#define GGML_ENDIAN_BYTE(i) ((unsigned char *)&(uint16_t){1})[i] +#define GGML_F16_VEC_STORE(p, r, i) \ + if (i & 0x1) \ + vec_xst(vec_pack_to_short_fp32(r[i - GGML_ENDIAN_BYTE(1)], \ + r[i - GGML_ENDIAN_BYTE(0)]), \ + 0, p - GGML_F16_EPR) + +#elif defined(__wasm_simd128__) + +#define GGML_SIMD + +// F32 WASM + +#define GGML_F32_STEP 16 +#define GGML_F32_EPR 4 + +#define GGML_F32x4 v128_t +#define GGML_F32x4_ZERO wasm_f32x4_splat(0.0f) +#define GGML_F32x4_SET1(x) wasm_f32x4_splat(x) +#define GGML_F32x4_LOAD wasm_v128_load +#define GGML_F32x4_STORE wasm_v128_store +#define GGML_F32x4_FMA(a, b, c) wasm_f32x4_add(wasm_f32x4_mul(b, c), a) +#define GGML_F32x4_ADD wasm_f32x4_add +#define GGML_F32x4_MUL wasm_f32x4_mul +#define GGML_F32x4_REDUCE(res, x) \ +{ \ + for (int i = 0; i < GGML_F32_ARR/2; ++i) { \ + x[2*i] = wasm_f32x4_add(x[2*i], x[2*i+1]); \ + } \ + for (int i = 0; i < GGML_F32_ARR/4; ++i) { \ + x[4*i] = wasm_f32x4_add(x[4*i], x[4*i+2]); \ + } \ + for (int i = 0; i < GGML_F32_ARR/8; ++i) { \ + x[8*i] = wasm_f32x4_add(x[8*i], x[8*i+4]); \ + } \ + res = wasm_f32x4_extract_lane(x[0], 0) + \ + wasm_f32x4_extract_lane(x[0], 1) + \ + wasm_f32x4_extract_lane(x[0], 2) + \ + wasm_f32x4_extract_lane(x[0], 3); \ +} + +#define GGML_F32_VEC GGML_F32x4 +#define GGML_F32_VEC_ZERO GGML_F32x4_ZERO +#define GGML_F32_VEC_SET1 GGML_F32x4_SET1 +#define GGML_F32_VEC_LOAD GGML_F32x4_LOAD +#define GGML_F32_VEC_STORE GGML_F32x4_STORE +#define GGML_F32_VEC_FMA GGML_F32x4_FMA +#define GGML_F32_VEC_ADD GGML_F32x4_ADD +#define GGML_F32_VEC_MUL GGML_F32x4_MUL +#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE + +// F16 WASM + +#define GGML_F16_STEP 16 +#define GGML_F16_EPR 4 + +inline static v128_t __wasm_f16x4_load(const ggml_fp16_t * p) { + float tmp[4]; + + tmp[0] = GGML_FP16_TO_FP32(p[0]); + tmp[1] = GGML_FP16_TO_FP32(p[1]); + tmp[2] = GGML_FP16_TO_FP32(p[2]); + tmp[3] = GGML_FP16_TO_FP32(p[3]); + + return wasm_v128_load(tmp); +} + +inline static void __wasm_f16x4_store(ggml_fp16_t * p, v128_t x) { + float tmp[4]; + + wasm_v128_store(tmp, x); + + p[0] = GGML_FP32_TO_FP16(tmp[0]); + p[1] = GGML_FP32_TO_FP16(tmp[1]); + p[2] = GGML_FP32_TO_FP16(tmp[2]); + p[3] = GGML_FP32_TO_FP16(tmp[3]); +} + +#define GGML_F16x4 v128_t +#define GGML_F16x4_ZERO wasm_f32x4_splat(0.0f) +#define GGML_F16x4_SET1(x) wasm_f32x4_splat(x) +#define GGML_F16x4_LOAD(x) __wasm_f16x4_load(x) +#define GGML_F16x4_STORE(x, y) __wasm_f16x4_store(x, y) +#define GGML_F16x4_FMA GGML_F32x4_FMA +#define GGML_F16x4_ADD wasm_f32x4_add +#define GGML_F16x4_MUL wasm_f32x4_mul +#define GGML_F16x4_REDUCE(res, x) \ +{ \ + for (int i = 0; i < GGML_F16_ARR/2; ++i) { \ + x[2*i] = wasm_f32x4_add(x[2*i], x[2*i+1]); \ + } \ + for (int i = 0; i < GGML_F16_ARR/4; ++i) { \ + x[4*i] = wasm_f32x4_add(x[4*i], x[4*i+2]); \ + } \ + for (int i = 0; i < GGML_F16_ARR/8; ++i) { \ + x[8*i] = wasm_f32x4_add(x[8*i], x[8*i+4]); \ + } \ + res = wasm_f32x4_extract_lane(x[0], 0) + \ + wasm_f32x4_extract_lane(x[0], 1) + \ + wasm_f32x4_extract_lane(x[0], 2) + \ + wasm_f32x4_extract_lane(x[0], 3); \ +} + +#define GGML_F16_VEC GGML_F16x4 +#define GGML_F16_VEC_ZERO GGML_F16x4_ZERO +#define GGML_F16_VEC_SET1 GGML_F16x4_SET1 +#define GGML_F16_VEC_LOAD(p, i) GGML_F16x4_LOAD(p) +#define 
GGML_F16_VEC_STORE(p, r, i) GGML_F16x4_STORE(p, r[i]) +#define GGML_F16_VEC_FMA GGML_F16x4_FMA +#define GGML_F16_VEC_ADD GGML_F16x4_ADD +#define GGML_F16_VEC_MUL GGML_F16x4_MUL +#define GGML_F16_VEC_REDUCE GGML_F16x4_REDUCE + +#elif defined(__SSE3__) + +#define GGML_SIMD + +// F32 SSE + +#define GGML_F32_STEP 32 +#define GGML_F32_EPR 4 + +#define GGML_F32x4 __m128 +#define GGML_F32x4_ZERO _mm_setzero_ps() +#define GGML_F32x4_SET1(x) _mm_set1_ps(x) +#define GGML_F32x4_LOAD _mm_loadu_ps +#define GGML_F32x4_STORE _mm_storeu_ps +#if defined(__FMA__) + // TODO: Does this work? + #define GGML_F32x4_FMA(a, b, c) _mm_fmadd_ps(b, c, a) +#else + #define GGML_F32x4_FMA(a, b, c) _mm_add_ps(_mm_mul_ps(b, c), a) +#endif +#define GGML_F32x4_ADD _mm_add_ps +#define GGML_F32x4_MUL _mm_mul_ps +#define GGML_F32x4_REDUCE(res, x) \ +{ \ + for (int i = 0; i < GGML_F32_ARR/2; ++i) { \ + x[2*i] = _mm_add_ps(x[2*i], x[2*i+1]); \ + } \ + for (int i = 0; i < GGML_F32_ARR/4; ++i) { \ + x[4*i] = _mm_add_ps(x[4*i], x[4*i+2]); \ + } \ + for (int i = 0; i < GGML_F32_ARR/8; ++i) { \ + x[8*i] = _mm_add_ps(x[8*i], x[8*i+4]); \ + } \ + const __m128 t0 = _mm_hadd_ps(x[0], x[0]); \ + res = _mm_cvtss_f32(_mm_hadd_ps(t0, t0)); \ +} +// TODO: is this optimal ? + +#define GGML_F32_VEC GGML_F32x4 +#define GGML_F32_VEC_ZERO GGML_F32x4_ZERO +#define GGML_F32_VEC_SET1 GGML_F32x4_SET1 +#define GGML_F32_VEC_LOAD GGML_F32x4_LOAD +#define GGML_F32_VEC_STORE GGML_F32x4_STORE +#define GGML_F32_VEC_FMA GGML_F32x4_FMA +#define GGML_F32_VEC_ADD GGML_F32x4_ADD +#define GGML_F32_VEC_MUL GGML_F32x4_MUL +#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE + +// F16 SSE + +#define GGML_F16_STEP 32 +#define GGML_F16_EPR 4 + +static inline __m128 __sse_f16x4_load(ggml_fp16_t *x) { + float tmp[4]; + + tmp[0] = GGML_FP16_TO_FP32(x[0]); + tmp[1] = GGML_FP16_TO_FP32(x[1]); + tmp[2] = GGML_FP16_TO_FP32(x[2]); + tmp[3] = GGML_FP16_TO_FP32(x[3]); + + return _mm_loadu_ps(tmp); +} + +static inline void __sse_f16x4_store(ggml_fp16_t *x, __m128 y) { + float arr[4]; + + _mm_storeu_ps(arr, y); + + x[0] = GGML_FP32_TO_FP16(arr[0]); + x[1] = GGML_FP32_TO_FP16(arr[1]); + x[2] = GGML_FP32_TO_FP16(arr[2]); + x[3] = GGML_FP32_TO_FP16(arr[3]); +} + +#define GGML_F32Cx4 __m128 +#define GGML_F32Cx4_ZERO _mm_setzero_ps() +#define GGML_F32Cx4_SET1(x) _mm_set1_ps(x) +#define GGML_F32Cx4_LOAD(x) __sse_f16x4_load(x) +#define GGML_F32Cx4_STORE(x, y) __sse_f16x4_store(x, y) +#define GGML_F32Cx4_FMA GGML_F32x4_FMA +#define GGML_F32Cx4_ADD _mm_add_ps +#define GGML_F32Cx4_MUL _mm_mul_ps +#define GGML_F32Cx4_REDUCE GGML_F32x4_REDUCE + +#define GGML_F16_VEC GGML_F32Cx4 +#define GGML_F16_VEC_ZERO GGML_F32Cx4_ZERO +#define GGML_F16_VEC_SET1 GGML_F32Cx4_SET1 +#define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx4_LOAD(p) +#define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE(p, r[i]) +#define GGML_F16_VEC_FMA GGML_F32Cx4_FMA +#define GGML_F16_VEC_ADD GGML_F32Cx4_ADD +#define GGML_F16_VEC_MUL GGML_F32Cx4_MUL +#define GGML_F16_VEC_REDUCE GGML_F32Cx4_REDUCE + +#endif + +// GGML_F32_ARR / GGML_F16_ARR +// number of registers to use per step +#ifdef GGML_SIMD +#define GGML_F32_ARR (GGML_F32_STEP/GGML_F32_EPR) +#define GGML_F16_ARR (GGML_F16_STEP/GGML_F16_EPR) +#endif + +// +// fundamental operations +// + +inline static void ggml_vec_set_i8(const int n, int8_t * x, const int8_t v) { for (int i = 0; i < n; ++i) x[i] = v; } + +inline static void ggml_vec_set_i16(const int n, int16_t * x, const int16_t v) { for (int i = 0; i < n; ++i) x[i] = v; } + +inline static void ggml_vec_set_i32(const int n, int32_t * x, 
const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; } + +inline static void ggml_vec_set_f16(const int n, ggml_fp16_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; } + +inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] + y[i]; } +inline static void ggml_vec_acc_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] += x[i]; } +inline static void ggml_vec_acc1_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] += v; } +inline static void ggml_vec_sub_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] - y[i]; } +inline static void ggml_vec_set_f32 (const int n, float * x, const float v) { for (int i = 0; i < n; ++i) x[i] = v; } +inline static void ggml_vec_cpy_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]; } +inline static void ggml_vec_neg_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = -x[i]; } +inline static void ggml_vec_mul_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]*y[i]; } +inline static void ggml_vec_div_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]/y[i]; } + +inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float * restrict x, const float * restrict y) { + ggml_float sumf = 0.0; + +#ifdef GGML_SIMD + const int np = (n & ~(GGML_F32_STEP - 1)); + + GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO }; + + GGML_F32_VEC ax[GGML_F32_ARR]; + GGML_F32_VEC ay[GGML_F32_ARR]; + + for (int i = 0; i < np; i += GGML_F32_STEP) { + for (int j = 0; j < GGML_F32_ARR; j++) { + ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR); + ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR); + + sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], ay[j]); + } + } + + // reduce sum0..sum3 to sum0 + GGML_F32_VEC_REDUCE(sumf, sum); + + // leftovers + for (int i = np; i < n; ++i) { + sumf += x[i]*y[i]; + } +#else + // scalar + for (int i = 0; i < n; ++i) { + sumf += x[i]*y[i]; + } +#endif + + *s = sumf; +} + +inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y) { + ggml_float sumf = 0.0; + +#if defined(GGML_SIMD) + const int np = (n & ~(GGML_F16_STEP - 1)); + + GGML_F16_VEC sum[GGML_F16_ARR] = { GGML_F16_VEC_ZERO }; + + GGML_F16_VEC ax[GGML_F16_ARR]; + GGML_F16_VEC ay[GGML_F16_ARR]; + + for (int i = 0; i < np; i += GGML_F16_STEP) { + for (int j = 0; j < GGML_F16_ARR; j++) { + ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j); + ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j); + + sum[j] = GGML_F16_VEC_FMA(sum[j], ax[j], ay[j]); + } + } + + // reduce sum0..sum3 to sum0 + GGML_F16_VEC_REDUCE(sumf, sum); + + // leftovers + for (int i = np; i < n; ++i) { + sumf += GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i]); + } +#else + for (int i = 0; i < n; ++i) { + sumf += GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i]); + } +#endif + + *s = sumf; +} + +// compute GGML_VEC_DOT_UNROLL dot products at once +// xs - x row stride in bytes +inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * restrict s, void * restrict xv, ggml_fp16_t * restrict y) { + ggml_float sumf[GGML_VEC_DOT_UNROLL] = { 0.0 }; + + ggml_fp16_t * restrict x[GGML_VEC_DOT_UNROLL]; + + for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) { + x[i] = 
(ggml_fp16_t *) ((char *) xv + i*xs); + } + +#if defined(GGML_SIMD) + const int np = (n & ~(GGML_F16_STEP - 1)); + + GGML_F16_VEC sum[GGML_VEC_DOT_UNROLL][GGML_F16_ARR] = { { GGML_F16_VEC_ZERO } }; + + GGML_F16_VEC ax[GGML_F16_ARR]; + GGML_F16_VEC ay[GGML_F16_ARR]; + + for (int i = 0; i < np; i += GGML_F16_STEP) { + for (int j = 0; j < GGML_F16_ARR; j++) { + ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j); + + for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) { + ax[j] = GGML_F16_VEC_LOAD(x[k] + i + j*GGML_F16_EPR, j); + + sum[k][j] = GGML_F16_VEC_FMA(sum[k][j], ax[j], ay[j]); + } + } + } + + // reduce sum0..sum3 to sum0 + for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) { + GGML_F16_VEC_REDUCE(sumf[k], sum[k]); + } + + // leftovers + for (int i = np; i < n; ++i) { + for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) { + sumf[j] += GGML_FP16_TO_FP32(x[j][i])*GGML_FP16_TO_FP32(y[i]); + } + } +#else + for (int i = 0; i < n; ++i) { + for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) { + sumf[j] += GGML_FP16_TO_FP32(x[j][i])*GGML_FP16_TO_FP32(y[i]); + } + } +#endif + + for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) { + s[i] = sumf[i]; + } +} + +inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float * restrict x, const float v) { +#if defined(GGML_SIMD) + const int np = (n & ~(GGML_F32_STEP - 1)); + + GGML_F32_VEC vx = GGML_F32_VEC_SET1(v); + + GGML_F32_VEC ax[GGML_F32_ARR]; + GGML_F32_VEC ay[GGML_F32_ARR]; + + for (int i = 0; i < np; i += GGML_F32_STEP) { + for (int j = 0; j < GGML_F32_ARR; j++) { + ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR); + ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR); + ay[j] = GGML_F32_VEC_FMA(ay[j], ax[j], vx); + + GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]); + } + } + + // leftovers + for (int i = np; i < n; ++i) { + y[i] += x[i]*v; + } +#else + // scalar + for (int i = 0; i < n; ++i) { + y[i] += x[i]*v; + } +#endif +} + +inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, ggml_fp16_t * restrict x, const float v) { +#if defined(GGML_SIMD) + const int np = (n & ~(GGML_F16_STEP - 1)); + + GGML_F16_VEC vx = GGML_F16_VEC_SET1(v); + + GGML_F16_VEC ax[GGML_F16_ARR]; + GGML_F16_VEC ay[GGML_F16_ARR]; + + for (int i = 0; i < np; i += GGML_F16_STEP) { + for (int j = 0; j < GGML_F16_ARR; j++) { + ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j); + ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j); + ay[j] = GGML_F16_VEC_FMA(ay[j], ax[j], vx); + + GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j); + } + } + + // leftovers + for (int i = np; i < n; ++i) { + GGML_ASSERT(false); + y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v); + } +#else + for (int i = 0; i < n; ++i) { + y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v); + } +#endif +} + +//inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] *= v; } +inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { +#if defined(GGML_SIMD) + const int np = (n & ~(GGML_F32_STEP - 1)); + + GGML_F32_VEC vx = GGML_F32_VEC_SET1(v); + + GGML_F32_VEC ay[GGML_F32_ARR]; + + for (int i = 0; i < np; i += GGML_F32_STEP) { + for (int j = 0; j < GGML_F32_ARR; j++) { + ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR); + ay[j] = GGML_F32_VEC_MUL(ay[j], vx); + + GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]); + } + } + + // leftovers + for (int i = np; i < n; ++i) { + y[i] *= v; + } +#else + // scalar + for (int i = 0; i < n; ++i) { + y[i] *= v; + } 
+#endif +} + +inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, x, x); *s = sqrt(*s); } +inline static void ggml_vec_sqr_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i]; } +inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrt(x[i]); } +inline static void ggml_vec_abs_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fabsf(x[i]); } +inline static void ggml_vec_sgn_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : ((x[i] < 0.f) ? -1.f : 0.f); } +inline static void ggml_vec_step_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : 0.f; } +inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; } + +static const ggml_float GELU_COEF_A = 0.044715; +static const ggml_float SQRT_2_OVER_PI = 0.79788456080286535587989211986876; + +inline static float ggml_gelu_f32(float x) { + return 0.5*x*(1.0 + tanh(SQRT_2_OVER_PI*x*(1.0 + GELU_COEF_A*x*x))); +} + +inline static void ggml_vec_gelu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { + const uint16_t * i16 = (const uint16_t *) x; + for (int i = 0; i < n; ++i) { + y[i] = table_gelu_f16[i16[i]]; + } +} + +#ifdef GGML_GELU_FP16 +inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) { + uint16_t t; + for (int i = 0; i < n; ++i) { + ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]); + memcpy(&t, &fp16, sizeof(uint16_t)); + y[i] = GGML_FP16_TO_FP32(table_gelu_f16[t]); + } +} +#else +inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) { + for (int i = 0; i < n; ++i) { + y[i] = ggml_gelu_f32(x[i]); + } +} +#endif + +inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) { +#ifndef GGML_USE_ACCELERATE + ggml_float sum = 0.0; + for (int i = 0; i < n; ++i) { + sum += x[i]; + } + *s = sum; +#else + vDSP_sve(x, 1, s, n); +#endif +} + +inline static void ggml_vec_max_f32(const int n, float * s, const float * x) { +#ifndef GGML_USE_ACCELERATE + ggml_float max = -INFINITY; + for (int i = 0; i < n; ++i) { + max = MAX(max, x[i]); + } + *s = max; +#else + vDSP_maxv(x, 1, s, n); +#endif +} + +inline static void ggml_vec_norm_inv_f32(const int n, float * s, const float * x) { ggml_vec_norm_f32(n, s, x); *s = 1./(*s); } + +// +// logging +// + +#if (GGML_DEBUG >= 1) +#define GGML_PRINT_DEBUG(...) printf(__VA_ARGS__) +#else +#define GGML_PRINT_DEBUG(...) +#endif + +#if (GGML_DEBUG >= 5) +#define GGML_PRINT_DEBUG_5(...) printf(__VA_ARGS__) +#else +#define GGML_PRINT_DEBUG_5(...) +#endif + +#if (GGML_DEBUG >= 10) +#define GGML_PRINT_DEBUG_10(...) printf(__VA_ARGS__) +#else +#define GGML_PRINT_DEBUG_10(...) +#endif + +#define GGML_PRINT(...) 
printf(__VA_ARGS__) + +// +// data types +// + +static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = { + sizeof(int8_t ), + sizeof(int16_t), + sizeof(int32_t), + sizeof(ggml_fp16_t), + sizeof(float ), +}; + +static const char * GGML_OP_LABEL[GGML_OP_COUNT] = { + "NONE", + + "DUP", + "ADD", + "SUB", + "MUL", + "DIV", + "SQR", + "SQRT", + "SUM", + "MEAN", + "REPEAT", + "ABS", + "SGN", + "NEG", + "STEP", + "RELU", + "GELU", + "NORM", + + "MUL_MAT", + + "SCALE", + "CPY", + "RESHAPE", + "VIEW", + "PERMUTE", + "TRANSPOSE", + "GET_ROWS", + "DIAG_MASK_INF", + "SOFT_MAX", + "ROPE", + "CONV_1D_1S", + "CONV_1D_2S", + + "FLASH_ATTN", + "FLASH_FF", +}; + +static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { + "none", + + "x", + "x+y", + "x-y", + "x*y", + "x/y", + "x^2", + "√x", + "Σx", + "Σx/n", + "repeat(x)", + "abs(x)", + "sgn(x)", + "-x", + "step(x)", + "relu(x)", + "gelu(x)", + "norm(x)", + + "X*Y", + + "x*v", + "x-\\>y", + "reshape(x)", + "view(x)", + "permute(x)", + "transpose(x)", + "get_rows(x)", + "diag_mask_inf(x)", + "soft_max(x)", + "rope(x)", + "conv_1d_1s(x)", + "conv_1d_2s(x)", + + "flash_attn(x)", + "flash_ff(x)", +}; + +// +// ggml object +// + +struct ggml_object { + size_t offs; + size_t size; + + struct ggml_object * next; + + char padding[8]; +}; + +static const size_t GGML_OBJECT_SIZE = sizeof(struct ggml_object); + +static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN"); +static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN"); + +// +// ggml context +// + +struct ggml_context { + size_t mem_size; + void * mem_buffer; + bool mem_buffer_owned; + + int n_objects; + + struct ggml_object * objects_begin; + struct ggml_object * objects_end; + + struct ggml_scratch scratch; + struct ggml_scratch scratch_save; +}; + +struct ggml_context_container { + bool used; + + struct ggml_context context; +}; + +// +// compute types +// + +enum ggml_task_type { + GGML_TASK_INIT = 0, + GGML_TASK_COMPUTE, + GGML_TASK_FINALIZE, +}; + +struct ggml_compute_params { + enum ggml_task_type type; + + int ith, nth; + + // work buffer for all threads + size_t wsize; + void * wdata; +}; + +// +// ggml state +// + +struct ggml_state { + struct ggml_context_container contexts[GGML_MAX_CONTEXTS]; +}; + +// global state +static struct ggml_state g_state; +static atomic_int g_state_barrier = 0; + +// barrier via spin lock +inline static void ggml_critical_section_start(void) { + int processing = atomic_fetch_add(&g_state_barrier, 1); + + while (processing > 0) { + // wait for other threads to finish + atomic_fetch_sub(&g_state_barrier, 1); + sched_yield(); // TODO: reconsider this + processing = atomic_fetch_add(&g_state_barrier, 1); + } +} + +// TODO: make this somehow automatically executed +// some sort of "sentry" mechanism +inline static void ggml_critical_section_end(void) { + atomic_fetch_sub(&g_state_barrier, 1); +} + +//////////////////////////////////////////////////////////////////////////////// + +void ggml_print_object(const struct ggml_object * obj) { + GGML_PRINT(" - ggml_object: offset = %zu, size = %zu, next = %p\n", + obj->offs, obj->size, (const void *) obj->next); +} + +void ggml_print_objects(const struct ggml_context * ctx) { + struct ggml_object * obj = ctx->objects_begin; + + GGML_PRINT("%s: objects in context %p:\n", __func__, (const void *) ctx); + + while (obj != NULL) { + ggml_print_object(obj); + obj = obj->next; + } + + GGML_PRINT("%s: --- end ---\n", 
__func__); +} + +int ggml_nelements(const struct ggml_tensor * tensor) { + static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); + + return tensor->ne[0]*tensor->ne[1]*tensor->ne[2]*tensor->ne[3]; +} + +int ggml_nrows(const struct ggml_tensor * tensor) { + static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); + + return tensor->ne[1]*tensor->ne[2]*tensor->ne[3]; +} + +size_t ggml_nbytes(const struct ggml_tensor * tensor) { + static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); + + return ggml_nelements(tensor)*GGML_TYPE_SIZE[tensor->type]; +} + +size_t ggml_type_size(enum ggml_type type) { + return GGML_TYPE_SIZE[type]; +} + +size_t ggml_element_size(const struct ggml_tensor * tensor) { + return GGML_TYPE_SIZE[tensor->type]; +} + +static inline bool ggml_is_scalar(const struct ggml_tensor * tensor) { + static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); + + return tensor->ne[0] == 1 && tensor->ne[1] == 1 && tensor->ne[2] == 1 && tensor->ne[3] == 1; +} + +static inline bool ggml_is_vector(const struct ggml_tensor * tensor) { + static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); + + return tensor->ne[1] == 1 && tensor->ne[2] == 1 && tensor->ne[3] == 1; +} + +static inline bool ggml_is_matrix(const struct ggml_tensor * tensor) { + static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); + + return tensor->ne[2] == 1 && tensor->ne[3] == 1; +} + +static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) { + static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); + + return + (t0->ne[0] == t1->ne[0]) && + (t0->ne[2] == t1->ne[2]) && + (t0->ne[3] == t1->ne[3]); +} + +static inline bool ggml_is_contiguous(const struct ggml_tensor * tensor) { + static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); + + return + tensor->nb[0] == GGML_TYPE_SIZE[tensor->type] && + tensor->nb[1] == tensor->nb[0]*tensor->ne[0] && + tensor->nb[2] == tensor->nb[1]*tensor->ne[1] && + tensor->nb[3] == tensor->nb[2]*tensor->ne[2]; +} + +static inline bool ggml_is_padded_1d(const struct ggml_tensor * tensor) { + static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); + + return + tensor->nb[0] == GGML_TYPE_SIZE[tensor->type] && + tensor->nb[2] == tensor->nb[1]*tensor->ne[1] && + tensor->nb[3] == tensor->nb[2]*tensor->ne[2]; +} + +static inline bool ggml_are_same_shape(const struct ggml_tensor * t0, const struct ggml_tensor * t1) { + static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); + + return + (t0->ne[0] == t1->ne[0] ) && + (t0->ne[1] == t1->ne[1] ) && + (t0->ne[2] == t1->ne[2] ) && + (t0->ne[3] == t1->ne[3] ); +} + +// check if t1 can be represented as a repeatition of t0 +static inline bool ggml_can_repeat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) { + static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); + + return + (t1->ne[0]%t0->ne[0] == 0) && + (t1->ne[1]%t0->ne[1] == 0) && + (t1->ne[2]%t0->ne[2] == 0) && + (t1->ne[3]%t0->ne[3] == 0); +} + +static inline int ggml_up32(int n) { + return (n + 31) & ~31; +} + +static inline int ggml_up64(int n) { + return (n + 63) & ~63; +} + +static inline int ggml_up(int n, int m) { + // assert m is a power of 2 + GGML_ASSERT((m & (m - 1)) == 0); + return (n + m - 1) & ~(m - 1); +} + +// assert that 
pointer is aligned to GGML_MEM_ALIGN +#define ggml_assert_aligned(ptr) \ + assert(((uintptr_t) (ptr))%GGML_MEM_ALIGN == 0) + +//////////////////////////////////////////////////////////////////////////////// + +struct ggml_context * ggml_init(struct ggml_init_params params) { + // make this function thread safe + ggml_critical_section_start(); + + static bool is_first_call = true; + + if (is_first_call) { + // initialize GELU, EXP and F32 tables + { + const uint64_t t_start = ggml_time_us(); UNUSED(t_start); + + ggml_fp16_t ii; + for (int i = 0; i < (1 << 16); ++i) { + uint16_t ui = i; + memcpy(&ii, &ui, sizeof(ii)); + const float f = table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(ii); + table_gelu_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_f32(f)); + table_exp_f16[i] = GGML_FP32_TO_FP16(exp(f)); + } + + const uint64_t t_end = ggml_time_us(); UNUSED(t_end); + + GGML_PRINT_DEBUG("%s: GELU and EXP tables initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f); + } + + // initialize g_state + { + const uint64_t t_start = ggml_time_us(); UNUSED(t_start); + + g_state = (struct ggml_state) { + /*.contexts =*/ { { 0 } }, + }; + + for (int i = 0; i < GGML_MAX_CONTEXTS; ++i) { + g_state.contexts[i].used = false; + } + + const uint64_t t_end = ggml_time_us(); UNUSED(t_end); + + GGML_PRINT_DEBUG("%s: g_state initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f); + } + + is_first_call = false; + } + + // find non-used context in g_state + struct ggml_context * ctx = NULL; + + for (int i = 0; i < GGML_MAX_CONTEXTS; i++) { + if (!g_state.contexts[i].used) { + g_state.contexts[i].used = true; + ctx = &g_state.contexts[i].context; + + GGML_PRINT_DEBUG("%s: found unused context %d\n", __func__, i); + break; + } + } + + if (ctx == NULL) { + GGML_PRINT_DEBUG("%s: no unused context found\n", __func__); + + ggml_critical_section_end(); + + return NULL; + } + + *ctx = (struct ggml_context) { + /*.mem_size =*/ params.mem_size, + /*.mem_buffer =*/ params.mem_buffer ? params.mem_buffer : malloc(params.mem_size), + /*.mem_buffer_owned =*/ params.mem_buffer ? false : true, + /*.n_objects =*/ 0, + /*.objects_begin =*/ NULL, + /*.objects_end =*/ NULL, + /*.scratch =*/ { 0, 0, NULL, }, + /*.scratch_save =*/ { 0, 0, NULL, }, + }; + + ggml_assert_aligned(ctx->mem_buffer); + + GGML_PRINT_DEBUG("%s: context initialized\n", __func__); + + ggml_critical_section_end(); + + return ctx; +} + +void ggml_free(struct ggml_context * ctx) { + // make this function thread safe + ggml_critical_section_start(); + + bool found = false; + + for (int i = 0; i < GGML_MAX_CONTEXTS; i++) { + if (&g_state.contexts[i].context == ctx) { + g_state.contexts[i].used = false; + + GGML_PRINT_DEBUG("%s: context %d with %d objects has been freed. memory used = %zu\n", + __func__, i, ctx->n_objects, ctx->objects_end->offs + ctx->objects_end->size); + + if (ctx->mem_buffer_owned) { + free(ctx->mem_buffer); + } + + found = true; + break; + } + } + + if (!found) { + GGML_PRINT_DEBUG("%s: context not found\n", __func__); + } + + ggml_critical_section_end(); +} + +size_t ggml_used_mem(const struct ggml_context * ctx) { + return ctx->objects_end->offs + ctx->objects_end->size; +} + +size_t ggml_set_scratch(struct ggml_context * ctx, struct ggml_scratch scratch) { + const size_t result = ctx->scratch.data ? 
ctx->scratch.offs : 0; + + ctx->scratch = scratch; + + return result; +} + +//////////////////////////////////////////////////////////////////////////////// + +struct ggml_tensor * ggml_new_tensor_impl( + struct ggml_context * ctx, + enum ggml_type type, + int n_dims, + const int* ne, + void* data) { + // always insert objects at the end of the context's memory pool + struct ggml_object * obj_cur = ctx->objects_end; + + const size_t cur_offs = obj_cur == NULL ? 0 : obj_cur->offs; + const size_t cur_size = obj_cur == NULL ? 0 : obj_cur->size; + const size_t cur_end = cur_offs + cur_size; + + size_t size_needed = 0; + + if (data == NULL) { + size_needed += GGML_TYPE_SIZE[type]; + for (int i = 0; i < n_dims; i++) { + size_needed *= ne[i]; + } + // align to GGML_MEM_ALIGN + size_needed = ((size_needed + GGML_MEM_ALIGN - 1)/GGML_MEM_ALIGN)*GGML_MEM_ALIGN; + } + + char * const mem_buffer = ctx->mem_buffer; + struct ggml_object * const obj_new = (struct ggml_object *)(mem_buffer + cur_end); + + if (ctx->scratch.data == NULL || data != NULL) { + size_needed += sizeof(struct ggml_tensor); + + if (cur_end + size_needed + GGML_OBJECT_SIZE > ctx->mem_size) { + GGML_PRINT("%s: not enough space in the context's memory pool (needed %zu, available %zu)\n", + __func__, cur_end + size_needed + GGML_OBJECT_SIZE, ctx->mem_size); + assert(false); + return NULL; + } + + *obj_new = (struct ggml_object) { + .offs = cur_end + GGML_OBJECT_SIZE, + .size = size_needed, + .next = NULL, + }; + } else { + if (ctx->scratch.offs + size_needed > ctx->scratch.size) { + GGML_PRINT("%s: not enough space in the scratch memory\n", __func__); + assert(false); + return NULL; + } + + if (cur_end + sizeof(struct ggml_tensor) + GGML_OBJECT_SIZE > ctx->mem_size) { + GGML_PRINT("%s: not enough space in the context's memory pool (needed %zu, available %zu)\n", + __func__, cur_end + sizeof(struct ggml_tensor) + GGML_OBJECT_SIZE, ctx->mem_size); + assert(false); + return NULL; + } + + data = (char * const) ctx->scratch.data + ctx->scratch.offs; + + *obj_new = (struct ggml_object) { + .offs = cur_end + GGML_OBJECT_SIZE, + .size = sizeof(struct ggml_tensor), + .next = NULL, + }; + + //printf("scratch offs = %zu, size_needed = %zu\n", ctx->scratch.offs, size_needed); + + ctx->scratch.offs += size_needed; + } + + if (obj_cur != NULL) { + obj_cur->next = obj_new; + } else { + // this is the first object in this context + ctx->objects_begin = obj_new; + } + + ctx->objects_end = obj_new; + + //printf("%s: inserted new object at %zu, size = %zu\n", __func__, cur_end, obj_new->size); + + struct ggml_tensor * const result = (struct ggml_tensor *)(mem_buffer + obj_new->offs); + + ggml_assert_aligned(result); + + *result = (struct ggml_tensor) { + /*.type =*/ type, + /*.n_dims =*/ n_dims, + /*.ne =*/ { 1, 1, 1, 1 }, + /*.nb =*/ { 0, 0, 0, 0 }, + /*.op =*/ GGML_OP_NONE, + /*.is_param =*/ false, + /*.grad =*/ NULL, + /*.src0 =*/ NULL, + /*.src1 =*/ NULL, + /*.opt =*/ { NULL }, + /*.n_tasks =*/ 0, + /*.perf_runs =*/ 0, + /*.perf_cycles =*/ 0, + /*.perf_time_us =*/ 0, + /*.data =*/ data == NULL ? 
(void *)(result + 1) : data, + /*.pad =*/ { 0 }, + }; + + ggml_assert_aligned(result->data); + + for (int i = 0; i < n_dims; i++) { + result->ne[i] = ne[i]; + } + + result->nb[0] = GGML_TYPE_SIZE[type]; + for (int i = 1; i < GGML_MAX_DIMS; i++) { + result->nb[i] = result->nb[i - 1]*result->ne[i - 1]; + } + + ctx->n_objects++; + + return result; +} + +struct ggml_tensor * ggml_new_tensor( + struct ggml_context * ctx, + enum ggml_type type, + int n_dims, + const int * ne) { + return ggml_new_tensor_impl(ctx, type, n_dims, ne, NULL); +} + +struct ggml_tensor * ggml_new_tensor_1d( + struct ggml_context * ctx, + enum ggml_type type, + int ne0) { + return ggml_new_tensor(ctx, type, 1, &ne0); +} + +struct ggml_tensor * ggml_new_tensor_2d( + struct ggml_context * ctx, + enum ggml_type type, + int ne0, + int ne1) { + const int ne[2] = { ne0, ne1 }; + return ggml_new_tensor(ctx, type, 2, ne); +} + +struct ggml_tensor * ggml_new_tensor_3d( + struct ggml_context * ctx, + enum ggml_type type, + int ne0, + int ne1, + int ne2) { + const int ne[3] = { ne0, ne1, ne2 }; + return ggml_new_tensor(ctx, type, 3, ne); +} + +struct ggml_tensor * ggml_new_tensor_4d( + struct ggml_context * ctx, + enum ggml_type type, + int ne0, + int ne1, + int ne2, + int ne3) { + const int ne[4] = { ne0, ne1, ne2, ne3 }; + return ggml_new_tensor(ctx, type, 4, ne); +} + +struct ggml_tensor * ggml_new_i32(struct ggml_context * ctx, int32_t value) { + ctx->scratch_save = ctx->scratch; + ctx->scratch.data = NULL; + + struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1); + + ctx->scratch = ctx->scratch_save; + + ggml_set_i32(result, value); + + return result; +} + +struct ggml_tensor * ggml_new_f32(struct ggml_context * ctx, float value) { + ctx->scratch_save = ctx->scratch; + ctx->scratch.data = NULL; + + struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1); + + ctx->scratch = ctx->scratch_save; + + ggml_set_f32(result, value); + + return result; +} + +struct ggml_tensor * ggml_dup_tensor(struct ggml_context * ctx, const struct ggml_tensor * src) { + return ggml_new_tensor_impl(ctx, src->type, src->n_dims, src->ne, NULL); +} + +struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor) { + memset(tensor->data, 0, ggml_nbytes(tensor)); + return tensor; +} + +struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value) { + const int n = ggml_nrows(tensor); + const int nc = tensor->ne[0]; + const size_t n1 = tensor->nb[1]; + + char * const data = tensor->data; + + switch (tensor->type) { + case GGML_TYPE_I8: + { + assert(tensor->nb[0] == sizeof(int8_t)); + for (int i = 0; i < n; i++) { + ggml_vec_set_i8(nc, (int8_t *)(data + i*n1), value); + } + } break; + case GGML_TYPE_I16: + { + assert(tensor->nb[0] == sizeof(int16_t)); + for (int i = 0; i < n; i++) { + ggml_vec_set_i16(nc, (int16_t *)(data + i*n1), value); + } + } break; + case GGML_TYPE_I32: + { + assert(tensor->nb[0] == sizeof(int32_t)); + for (int i = 0; i < n; i++) { + ggml_vec_set_i32(nc, (int32_t *)(data + i*n1), value); + } + } break; + case GGML_TYPE_F16: + { + assert(tensor->nb[0] == sizeof(ggml_fp16_t)); + for (int i = 0; i < n; i++) { + ggml_vec_set_f16(nc, (ggml_fp16_t *)(data + i*n1), value); + } + } break; + case GGML_TYPE_F32: + { + assert(tensor->nb[0] == sizeof(float)); + for (int i = 0; i < n; i++) { + ggml_vec_set_f32(nc, (float *)(data + i*n1), value); + } + } break; + case GGML_TYPE_COUNT: + { + assert(false); + } break; + } + + return tensor; +} + +struct ggml_tensor * ggml_set_f32(struct ggml_tensor 
* tensor, float value) { + const int n = ggml_nrows(tensor); + const int nc = tensor->ne[0]; + const size_t n1 = tensor->nb[1]; + + char * const data = tensor->data; + + switch (tensor->type) { + case GGML_TYPE_I8: + { + assert(tensor->nb[0] == sizeof(int8_t)); + for (int i = 0; i < n; i++) { + ggml_vec_set_i8(nc, (int8_t *)(data + i*n1), value); + } + } break; + case GGML_TYPE_I16: + { + assert(tensor->nb[0] == sizeof(int16_t)); + for (int i = 0; i < n; i++) { + ggml_vec_set_i16(nc, (int16_t *)(data + i*n1), value); + } + } break; + case GGML_TYPE_I32: + { + assert(tensor->nb[0] == sizeof(int32_t)); + for (int i = 0; i < n; i++) { + ggml_vec_set_i32(nc, (int32_t *)(data + i*n1), value); + } + } break; + case GGML_TYPE_F16: + { + assert(tensor->nb[0] == sizeof(ggml_fp16_t)); + for (int i = 0; i < n; i++) { + ggml_vec_set_f16(nc, (ggml_fp16_t *)(data + i*n1), value); + } + } break; + case GGML_TYPE_F32: + { + assert(tensor->nb[0] == sizeof(float)); + for (int i = 0; i < n; i++) { + ggml_vec_set_f32(nc, (float *)(data + i*n1), value); + } + } break; + case GGML_TYPE_COUNT: + { + assert(false); + } break; + } + + return tensor; +} + +int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) { + switch (tensor->type) { + case GGML_TYPE_I8: + { + GGML_ASSERT(tensor->nb[0] == sizeof(int8_t)); + return ((int8_t *)(tensor->data))[i]; + } break; + case GGML_TYPE_I16: + { + GGML_ASSERT(tensor->nb[0] == sizeof(int16_t)); + return ((int16_t *)(tensor->data))[i]; + } break; + case GGML_TYPE_I32: + { + GGML_ASSERT(tensor->nb[0] == sizeof(int32_t)); + return ((int32_t *)(tensor->data))[i]; + } break; + case GGML_TYPE_F16: + { + GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t)); + return GGML_FP16_TO_FP32(((ggml_fp16_t *)(tensor->data))[i]); + } break; + case GGML_TYPE_F32: + { + GGML_ASSERT(tensor->nb[0] == sizeof(float)); + return ((float *)(tensor->data))[i]; + } break; + case GGML_TYPE_COUNT: + { + GGML_ASSERT(false); + } break; + } + + return 0.0f; +} + +void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) { + switch (tensor->type) { + case GGML_TYPE_I8: + { + GGML_ASSERT(tensor->nb[0] == sizeof(int8_t)); + ((int8_t *)(tensor->data))[i] = value; + } break; + case GGML_TYPE_I16: + { + GGML_ASSERT(tensor->nb[0] == sizeof(int16_t)); + ((int16_t *)(tensor->data))[i] = value; + } break; + case GGML_TYPE_I32: + { + GGML_ASSERT(tensor->nb[0] == sizeof(int32_t)); + ((int32_t *)(tensor->data))[i] = value; + } break; + case GGML_TYPE_F16: + { + GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t)); + ((ggml_fp16_t *)(tensor->data))[i] = GGML_FP32_TO_FP16(value); + } break; + case GGML_TYPE_F32: + { + GGML_ASSERT(tensor->nb[0] == sizeof(float)); + ((float *)(tensor->data))[i] = value; + } break; + case GGML_TYPE_COUNT: + { + GGML_ASSERT(false); + } break; + } +} + +float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) { + switch (tensor->type) { + case GGML_TYPE_I8: + { + GGML_ASSERT(tensor->nb[0] == sizeof(int8_t)); + return ((int8_t *)(tensor->data))[i]; + } break; + case GGML_TYPE_I16: + { + GGML_ASSERT(tensor->nb[0] == sizeof(int16_t)); + return ((int16_t *)(tensor->data))[i]; + } break; + case GGML_TYPE_I32: + { + GGML_ASSERT(tensor->nb[0] == sizeof(int32_t)); + return ((int32_t *)(tensor->data))[i]; + } break; + case GGML_TYPE_F16: + { + GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t)); + return GGML_FP16_TO_FP32(((ggml_fp16_t *)(tensor->data))[i]); + } break; + case GGML_TYPE_F32: + { + GGML_ASSERT(tensor->nb[0] == sizeof(float)); + return ((float 
*)(tensor->data))[i]; + } break; + case GGML_TYPE_COUNT: + { + GGML_ASSERT(false); + } break; + } + + return 0.0f; +} + +void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) { + switch (tensor->type) { + case GGML_TYPE_I8: + { + GGML_ASSERT(tensor->nb[0] == sizeof(int8_t)); + ((int8_t *)(tensor->data))[i] = value; + } break; + case GGML_TYPE_I16: + { + GGML_ASSERT(tensor->nb[0] == sizeof(int16_t)); + ((int16_t *)(tensor->data))[i] = value; + } break; + case GGML_TYPE_I32: + { + GGML_ASSERT(tensor->nb[0] == sizeof(int32_t)); + ((int32_t *)(tensor->data))[i] = value; + } break; + case GGML_TYPE_F16: + { + GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t)); + ((ggml_fp16_t *)(tensor->data))[i] = GGML_FP32_TO_FP16(value); + } break; + case GGML_TYPE_F32: + { + GGML_ASSERT(tensor->nb[0] == sizeof(float)); + ((float *)(tensor->data))[i] = value; + } break; + case GGML_TYPE_COUNT: + { + GGML_ASSERT(false); + } break; + } +} + +void * ggml_get_data(const struct ggml_tensor * tensor) { + return tensor->data; +} + +float * ggml_get_data_f32(const struct ggml_tensor * tensor) { + assert(tensor->type == GGML_TYPE_F32); + return (float *)(tensor->data); +} + +struct ggml_tensor * ggml_view_tensor( + struct ggml_context * ctx, + const struct ggml_tensor * src) { + return ggml_new_tensor_impl(ctx, src->type, src->n_dims, src->ne, src->data); +} + +//////////////////////////////////////////////////////////////////////////////// + +// ggml_dup + +struct ggml_tensor * ggml_dup_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + bool inplace) { + bool is_node = false; + + if (!inplace && (a->grad)) { + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_DUP; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; + + return result; +} + +struct ggml_tensor * ggml_dup( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_dup_impl(ctx, a, false); +} + +struct ggml_tensor * ggml_dup_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_dup_impl(ctx, a, true); +} + +// ggml_add + +struct ggml_tensor * ggml_add_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + bool inplace) { + assert(ggml_are_same_shape(a, b)); + + bool is_node = false; + + if (!inplace && (a->grad || b->grad)) { + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_ADD; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = b; + + return result; +} + +struct ggml_tensor * ggml_add( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_add_impl(ctx, a, b, false); +} + +struct ggml_tensor * ggml_add_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_add_impl(ctx, a, b, true); +} + +// ggml_sub + +struct ggml_tensor * ggml_sub_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + bool inplace) { + assert(ggml_are_same_shape(a, b)); + + bool is_node = false; + + if (!inplace && (a->grad || b->grad)) { + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_SUB; + result->grad = is_node ? 
ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = b; + + return result; +} + +struct ggml_tensor * ggml_sub( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_sub_impl(ctx, a, b, false); +} + +struct ggml_tensor * ggml_sub_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_sub_impl(ctx, a, b, true); +} + +// ggml_mul + +struct ggml_tensor * ggml_mul_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + bool inplace) { + assert(ggml_are_same_shape(a, b)); + + bool is_node = false; + + if (!inplace && (a->grad || b->grad)) { + is_node = true; + } + + if (inplace) { + assert(is_node == false); + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_MUL; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = b; + + return result; +} + +struct ggml_tensor * ggml_mul( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_mul_impl(ctx, a, b, false); +} + +struct ggml_tensor * ggml_mul_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_mul_impl(ctx, a, b, true); +} + +// ggml_div + +struct ggml_tensor * ggml_div_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + bool inplace) { + assert(ggml_are_same_shape(a, b)); + + bool is_node = false; + + if (!inplace && (a->grad || b->grad)) { + is_node = true; + } + + if (inplace) { + assert(is_node == false); + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_DIV; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = b; + + return result; +} + +struct ggml_tensor * ggml_div( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_div_impl(ctx, a, b, false); +} + +struct ggml_tensor * ggml_div_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_div_impl(ctx, a, b, true); +} + +// ggml_sqr + +struct ggml_tensor * ggml_sqr_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + bool inplace) { + bool is_node = false; + + if (!inplace && (a->grad)) { + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_SQR; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; + + return result; +} + +struct ggml_tensor * ggml_sqr( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_sqr_impl(ctx, a, false); +} + +struct ggml_tensor * ggml_sqr_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_sqr_impl(ctx, a, true); +} + +// ggml_sqrt + +struct ggml_tensor * ggml_sqrt_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + bool inplace) { + bool is_node = false; + + if (!inplace && (a->grad)) { + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_SQRT; + result->grad = is_node ? 
ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; + + return result; +} + +struct ggml_tensor * ggml_sqrt( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_sqrt_impl(ctx, a, false); +} + +struct ggml_tensor * ggml_sqrt_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_sqrt_impl(ctx, a, true); +} + +// ggml_sum + +struct ggml_tensor * ggml_sum( + struct ggml_context * ctx, + struct ggml_tensor * a) { + bool is_node = false; + + if (a->grad) { + is_node = true; + } + + struct ggml_tensor * result = ggml_new_tensor_1d(ctx, a->type, 1); + + result->op = GGML_OP_SUM; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; + + return result; +} + +// ggml_mean + +struct ggml_tensor * ggml_mean( + struct ggml_context * ctx, + struct ggml_tensor * a) { + bool is_node = false; + + if (a->grad) { + assert(false); // TODO: implement + is_node = true; + } + + int ne[GGML_MAX_DIMS] = { 1, a->ne[1], a->ne[2], a->ne[3] }; + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, a->n_dims, ne); + + result->op = GGML_OP_MEAN; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; + + return result; +} + +// ggml_repeat + +struct ggml_tensor * ggml_repeat( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + assert(ggml_can_repeat(a, b)); + + bool is_node = false; + + if (a->grad) { + is_node = true; + } + + if (ggml_are_same_shape(a, b) && !is_node) { + return a; + } + + struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, b->n_dims, b->ne); + + result->op = GGML_OP_REPEAT; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = b; + + return result; +} + +// ggml_abs + +struct ggml_tensor * ggml_abs_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + bool inplace) { + bool is_node = false; + + if (!inplace && (a->grad)) { + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_ABS; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; + + return result; +} + +struct ggml_tensor * ggml_abs( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_abs_impl(ctx, a, false); +} + +struct ggml_tensor * ggml_abs_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_abs_impl(ctx, a, true); +} + + +// ggml_sgn + +struct ggml_tensor * ggml_sgn_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + bool inplace) { + bool is_node = false; + + if (!inplace && (a->grad)) { + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_SGN; + result->grad = is_node ? 
ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; + + return result; +} + +struct ggml_tensor * ggml_sgn( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_sgn_impl(ctx, a, false); +} + +struct ggml_tensor * ggml_sgn_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_sgn_impl(ctx, a, true); +} + +// ggml_neg + +struct ggml_tensor * ggml_neg_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + bool inplace) { + bool is_node = false; + + if (!inplace && (a->grad)) { + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_NEG; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; + + return result; +} + +struct ggml_tensor * ggml_neg( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_neg_impl(ctx, a, false); +} + +struct ggml_tensor * ggml_neg_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_neg_impl(ctx, a, true); +} + +// ggml_step + +struct ggml_tensor * ggml_step_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + bool inplace) { + bool is_node = false; + + if (!inplace && (a->grad)) { + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_STEP; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; + + return result; +} + +struct ggml_tensor * ggml_step( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_step_impl(ctx, a, false); +} + +struct ggml_tensor * ggml_step_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_step_impl(ctx, a, true); +} + +// ggml_relu + +struct ggml_tensor * ggml_relu_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + bool inplace) { + bool is_node = false; + + if (!inplace && (a->grad)) { + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_RELU; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; + + return result; +} + +struct ggml_tensor * ggml_relu( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_relu_impl(ctx, a, false); +} + +struct ggml_tensor * ggml_relu_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_relu_impl(ctx, a, true); +} + +// ggml_gelu + +struct ggml_tensor * ggml_gelu_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + bool inplace) { + bool is_node = false; + + if (!inplace && (a->grad)) { + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_GELU; + result->grad = is_node ? 
ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; + + return result; +} + +struct ggml_tensor * ggml_gelu( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_gelu_impl(ctx, a, false); +} + +struct ggml_tensor * ggml_gelu_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_gelu_impl(ctx, a, true); +} + +// ggml_norm + +struct ggml_tensor * ggml_norm_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + bool inplace) { + bool is_node = false; + + if (!inplace && (a->grad)) { + assert(false); // TODO: implement backward + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_NORM; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; // TODO: maybe store epsilon here? + + return result; +} + +struct ggml_tensor * ggml_norm( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_norm_impl(ctx, a, false); +} + +struct ggml_tensor * ggml_norm_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_norm_impl(ctx, a, true); +} + +// ggml_mul_mat + +struct ggml_tensor * ggml_mul_mat( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + assert(ggml_can_mul_mat(a, b)); + + bool is_node = false; + + if (a->grad || b->grad) { + is_node = true; + } + + const int ne[4] = { a->ne[1], b->ne[1], a->ne[2], b->ne[3] }; + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MIN(a->n_dims, b->n_dims), ne); + + result->op = GGML_OP_MUL_MAT; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = b; + + return result; +} + +// ggml_scale + +struct ggml_tensor * ggml_scale_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + bool inplace) { + assert(ggml_is_scalar(b)); + assert(ggml_is_padded_1d(a)); + + bool is_node = false; + + if (!inplace && (a->grad || b->grad)) { + assert(false); // TODO: implement backward + is_node = true; + } + + // TODO: when implement backward, fix this: + //struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + struct ggml_tensor * result = ggml_view_tensor(ctx, a); + + result->op = GGML_OP_SCALE; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = b; + + return result; +} + +struct ggml_tensor * ggml_scale( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_scale_impl(ctx, a, b, false); +} + +struct ggml_tensor * ggml_scale_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_scale_impl(ctx, a, b, true); +} + +// ggml_cpy + +struct ggml_tensor * ggml_cpy_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + bool inplace) { + assert(ggml_nelements(a) == ggml_nelements(b)); + + bool is_node = false; + + if (!inplace && (a->grad || b->grad)) { + assert(false); // TODO: implement backward + is_node = true; + } + + // make a view of the destination + struct ggml_tensor * result = ggml_view_tensor(ctx, b); + + result->op = GGML_OP_CPY; + result->grad = is_node ? 
ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = b; + + return result; +} + +struct ggml_tensor * ggml_cpy( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_cpy_impl(ctx, a, b, false); +} + +struct ggml_tensor * ggml_cpy_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_cpy_impl(ctx, a, b, true); +} + +// ggml_reshape + +struct ggml_tensor * ggml_reshape( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + assert(ggml_is_contiguous(a)); + assert(ggml_is_contiguous(b)); + assert(ggml_nelements(a) == ggml_nelements(b)); + + bool is_node = false; + + if (a->grad || b->grad) { + assert(false); // TODO: implement backward + is_node = true; + } + + struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, b->n_dims, b->ne, a->data); + + result->op = GGML_OP_RESHAPE; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; + + return result; +} + +struct ggml_tensor * ggml_reshape_2d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int ne0, + int ne1) { + assert(ggml_is_contiguous(a)); + assert(ggml_nelements(a) == ne0*ne1); + + bool is_node = false; + + if (a->grad) { + assert(false); // TODO: implement backward + is_node = true; + } + + const int ne[2] = { ne0, ne1 }; + struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 2, ne, a->data); + + result->op = GGML_OP_RESHAPE; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; + + return result; +} + +struct ggml_tensor * ggml_reshape_3d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int ne0, + int ne1, + int ne2) { + assert(ggml_is_contiguous(a)); + assert(ggml_nelements(a) == ne0*ne1*ne2); + + bool is_node = false; + + if (a->grad) { + assert(false); // TODO: implement backward + is_node = true; + } + + const int ne[3] = { ne0, ne1, ne2 }; + struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 3, ne, a->data); + + result->op = GGML_OP_RESHAPE; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; + + return result; +} + +// ggml_view_1d + +struct ggml_tensor * ggml_view_1d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int ne0, + size_t offset) { + if (a->grad) { + assert(false); // gradient propagation is not supported + } + + struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 1, &ne0, (char *) a->data + offset); + + result->op = GGML_OP_VIEW; + result->grad = NULL; + result->src0 = a; + result->src1 = NULL; // TODO: maybe store the offset here? + + return result; +} + +// ggml_view_2d + +struct ggml_tensor * ggml_view_2d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int ne0, + int ne1, + size_t nb1, + size_t offset) { + if (a->grad) { + assert(false); // gradient propagation is not supported + } + + const int ne[GGML_MAX_DIMS] = { ne0, ne1, 1, 1 }; + + struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 2, ne, (char *) a->data + offset); + + result->nb[1] = nb1; + result->nb[2] = result->nb[1]*ne1; + result->nb[3] = result->nb[2]; + + result->op = GGML_OP_VIEW; + result->grad = NULL; + result->src0 = a; + result->src1 = NULL; // TODO: maybe store the offset here? 
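+    // a brief illustrative sketch (not used elsewhere in this file): assuming a
+    // contiguous F32 matrix `a`, the top-left block of 2 rows x 3 elements can be
+    // viewed in place by reusing the parent row stride:
+    //
+    //   struct ggml_tensor * blk = ggml_view_2d(ctx, a, 3, 2, a->nb[1], 0);
+    //
+    // `blk` (a hypothetical name) shares a->data, so writes through the view land
+    // in the parent tensor; a non-zero byte offset selects a different sub-block.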
+ + return result; +} + +// ggml_permute + +struct ggml_tensor * ggml_permute( + struct ggml_context * ctx, + struct ggml_tensor * a, + int axis0, + int axis1, + int axis2, + int axis3) { + assert(axis0 >= 0 && axis0 < GGML_MAX_DIMS); + assert(axis1 >= 0 && axis1 < GGML_MAX_DIMS); + assert(axis2 >= 0 && axis2 < GGML_MAX_DIMS); + assert(axis3 >= 0 && axis3 < GGML_MAX_DIMS); + + assert(axis0 != axis1); + assert(axis0 != axis2); + assert(axis0 != axis3); + assert(axis1 != axis2); + assert(axis1 != axis3); + assert(axis2 != axis3); + + bool is_node = false; + + if (a->grad) { + assert(false); // TODO: implement backward + is_node = true; + } + + struct ggml_tensor * result = ggml_view_tensor(ctx, a); + + int ne[GGML_MAX_DIMS]; + int nb[GGML_MAX_DIMS]; + + ne[axis0] = a->ne[0]; + ne[axis1] = a->ne[1]; + ne[axis2] = a->ne[2]; + ne[axis3] = a->ne[3]; + + nb[axis0] = a->nb[0]; + nb[axis1] = a->nb[1]; + nb[axis2] = a->nb[2]; + nb[axis3] = a->nb[3]; + + result->ne[0] = ne[0]; + result->ne[1] = ne[1]; + result->ne[2] = ne[2]; + result->ne[3] = ne[3]; + + result->nb[0] = nb[0]; + result->nb[1] = nb[1]; + result->nb[2] = nb[2]; + result->nb[3] = nb[3]; + + result->op = GGML_OP_PERMUTE; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; // TODO: maybe store the permutation here? + + return result; +} + +// ggml_transpose + +struct ggml_tensor * ggml_transpose( + struct ggml_context * ctx, + struct ggml_tensor * a) { + bool is_node = false; + + if (a->grad) { + assert(false); // TODO: implement backward + is_node = true; + } + + struct ggml_tensor * result = ggml_view_tensor(ctx, a); + + result->ne[0] = a->ne[1]; + result->ne[1] = a->ne[0]; + + result->nb[0] = a->nb[1]; + result->nb[1] = a->nb[0]; + + result->op = GGML_OP_TRANSPOSE; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; + + return result; +} + +// ggml_get_rows + +struct ggml_tensor * ggml_get_rows( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + assert(ggml_is_matrix(a) && ggml_is_vector(b) && b->type == GGML_TYPE_I32); + + bool is_node = false; + + if (a->grad || b->grad) { + assert(false); // TODO: implement backward + is_node = true; + } + + // TODO: implement non F32 return + //struct ggml_tensor * result = ggml_new_tensor_2d(ctx, a->type, a->ne[0], b->ne[0]); + struct ggml_tensor * result = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, a->ne[0], b->ne[0]); + + result->op = GGML_OP_GET_ROWS; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = b; + + return result; +} + +// ggml_diag_mask_inf + +struct ggml_tensor * ggml_diag_mask_inf( + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_past) { + bool is_node = false; + + if (a->grad) { + assert(false); // TODO: implement backward + is_node = true; + } + + // TODO: when implement backward, fix this: + //struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + struct ggml_tensor * result = ggml_view_tensor(ctx, a); + struct ggml_tensor * b = ggml_new_i32(ctx, n_past); + + result->op = GGML_OP_DIAG_MASK_INF; + result->grad = is_node ? 
ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = b; + + return result; +} + +// ggml_soft_max + +struct ggml_tensor * ggml_soft_max( + struct ggml_context * ctx, + struct ggml_tensor * a) { + bool is_node = false; + + if (a->grad) { + assert(false); // TODO: implement backward + is_node = true; + } + + // TODO: when implement backward, fix this: + //struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + struct ggml_tensor * result = ggml_view_tensor(ctx, a); + + result->op = GGML_OP_SOFT_MAX; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; + + return result; +} + +// ggml_rope + +struct ggml_tensor * ggml_rope( + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_past, + int n_dims, + int mode) { + assert(n_past >= 0); + bool is_node = false; + + if (a->grad) { + assert(false); // TODO: implement backward + is_node = true; + } + + // TODO: when implement backward, fix this: + //struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + struct ggml_tensor * result = ggml_view_tensor(ctx, a); + + struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 3); + ((int32_t *) b->data)[0] = n_past; + ((int32_t *) b->data)[1] = n_dims; + ((int32_t *) b->data)[2] = mode; + + result->op = GGML_OP_ROPE; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = b; + + return result; +} + +// ggml_conv_1d_1s + +struct ggml_tensor * ggml_conv_1d_1s( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + assert(ggml_is_matrix(b)); + assert(a->ne[1] == b->ne[1]); + assert(a->ne[3] == 1); + bool is_node = false; + + if (a->grad || b->grad) { + assert(false); // TODO: implement backward + is_node = true; + } + + const int ne[4] = { b->ne[0], a->ne[2], 1, 1, }; + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 2, ne); + + result->op = GGML_OP_CONV_1D_1S; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = b; + + return result; +} + +// ggml_conv_1d_2s + +struct ggml_tensor * ggml_conv_1d_2s( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + assert(ggml_is_matrix(b)); + assert(a->ne[1] == b->ne[1]); + assert(a->ne[3] == 1); + bool is_node = false; + + if (a->grad || b->grad) { + assert(false); // TODO: implement backward + is_node = true; + } + + const int ne[4] = { b->ne[0]/2, a->ne[2], 1, 1, }; + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 2, ne); + + result->op = GGML_OP_CONV_1D_2S; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = b; + + return result; +} + +// ggml_flash_attn + +struct ggml_tensor * ggml_flash_attn( + struct ggml_context * ctx, + struct ggml_tensor * q, + struct ggml_tensor * k, + struct ggml_tensor * v, + bool masked) { + assert(ggml_can_mul_mat(k, q)); + // TODO: check if vT can be multiplied by (k*qT) + + bool is_node = false; + + if (q->grad || k->grad || v->grad) { + GGML_ASSERT(false); // TODO: implement backward + is_node = true; + } + + //struct ggml_tensor * result = ggml_dup_tensor(ctx, q); + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, q->ne); + + result->op = GGML_OP_FLASH_ATTN; + result->grad = is_node ? 
ggml_dup_tensor(ctx, result) : NULL; + result->src0 = q; + result->src1 = k; + result->opt[0] = v; + result->opt[1] = ggml_new_i32(ctx, masked ? 1 : 0); + + return result; +} + +// ggml_flash_ff + +struct ggml_tensor * ggml_flash_ff( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b0, + struct ggml_tensor * b1, + struct ggml_tensor * c0, + struct ggml_tensor * c1) { + assert(ggml_can_mul_mat(b0, a)); + // TODO: more checks + + bool is_node = false; + + if (a->grad || b0->grad || b1->grad || c0->grad || c1->grad) { + GGML_ASSERT(false); // TODO: implement backward + is_node = true; + } + + //struct ggml_tensor * result = ggml_dup_tensor(ctx, a); + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, a->ne); + + result->op = GGML_OP_FLASH_FF; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = b0; + result->opt[0] = b1; + result->opt[1] = c0; + result->opt[2] = c1; + + return result; +} + +//////////////////////////////////////////////////////////////////////////////// + +void ggml_set_param( + struct ggml_context * ctx, + struct ggml_tensor * tensor) { + tensor->is_param = true; + + assert(tensor->grad == NULL); + tensor->grad = ggml_dup_tensor(ctx, tensor); +} + +// ggml_compute_forward_dup + +static void ggml_compute_forward_dup_f16( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + assert(params->ith == 0); + assert(ggml_is_contiguous(dst)); + assert(ggml_nelements(dst) == ggml_nelements(src0)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; + const int ne03 = src0->ne[3]; + + const size_t nb00 = src0->nb[0]; + const size_t nb01 = src0->nb[1]; + const size_t nb02 = src0->nb[2]; + const size_t nb03 = src0->nb[3]; + + if (ggml_is_contiguous(src0) && src0->type == dst->type) { + memcpy(dst->data, src0->data, ggml_nelements(dst) * GGML_TYPE_SIZE[src0->type]); + return; + } + + if (src0->nb[0] == sizeof(ggml_fp16_t)) { + if (dst->type == GGML_TYPE_F16) { + int id = 0; + const size_t rs = ne00*nb00; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + for (int i01 = 0; i01 < ne01; i01++) { + const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03; + char * dst_ptr = (char *) dst->data + id*rs; + + memcpy(dst_ptr, src0_ptr, rs); + + id++; + } + } + } + } else if (dst->type == GGML_TYPE_F32) { + int id = 0; + float * dst_ptr = (float *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + for (int i01 = 0; i01 < ne01; i01++) { + for (int i00 = 0; i00 < ne00; i00++) { + const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + + dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr); + id++; + } + } + } + } + } else { + GGML_ASSERT(false); // TODO: implement + } + } else { + //printf("%s: this is not optimal - fix me\n", __func__); + + if (dst->type == GGML_TYPE_F32) { + int id = 0; + float * dst_ptr = (float *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + for (int i01 = 0; i01 < ne01; i01++) { + for (int i00 = 0; i00 < ne00; i00++) { + const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + + dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr); + id++; + } 
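+                    // (the innermost loop above gathers one fp16 value at a time via the
+                    //  byte strides and widens it to f32; this is the slow, strided path)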
+ } + } + } + } else if (dst->type == GGML_TYPE_F16) { + int id = 0; + ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + for (int i01 = 0; i01 < ne01; i01++) { + for (int i00 = 0; i00 < ne00; i00++) { + const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + + dst_ptr[id] = *src0_ptr; + id++; + } + } + } + } + } else { + GGML_ASSERT(false); // TODO: implement + } + } +} + +static void ggml_compute_forward_dup_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + GGML_ASSERT(params->ith == 0); + GGML_ASSERT(ggml_is_contiguous(dst)); + GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; + const int ne03 = src0->ne[3]; + + const size_t nb00 = src0->nb[0]; + const size_t nb01 = src0->nb[1]; + const size_t nb02 = src0->nb[2]; + const size_t nb03 = src0->nb[3]; + + if (ggml_is_contiguous(src0) && src0->type == dst->type) { + memcpy(dst->data, src0->data, ggml_nelements(dst) * GGML_TYPE_SIZE[src0->type]); + return; + } + + if (src0->nb[0] == sizeof(float)) { + if (dst->type == GGML_TYPE_F32) { + int id = 0; + const size_t rs = ne00*nb00; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + for (int i01 = 0; i01 < ne01; i01++) { + const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03; + char * dst_ptr = (char *) dst->data + id*rs; + + memcpy(dst_ptr, src0_ptr, rs); + + id++; + } + } + } + } else if (dst->type == GGML_TYPE_F16) { + int id = 0; + ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + for (int i01 = 0; i01 < ne01; i01++) { + for (int i00 = 0; i00 < ne00; i00++) { + const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + + dst_ptr[id] = GGML_FP32_TO_FP16(*src0_ptr); + id++; + } + } + } + } + } else { + GGML_ASSERT(false); // TODO: implement + } + } else { + //printf("%s: this is not optimal - fix me\n", __func__); + + if (dst->type == GGML_TYPE_F32) { + int id = 0; + float * dst_ptr = (float *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + for (int i01 = 0; i01 < ne01; i01++) { + for (int i00 = 0; i00 < ne00; i00++) { + const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + + dst_ptr[id] = *src0_ptr; + id++; + } + } + } + } + } else if (dst->type == GGML_TYPE_F16) { + int id = 0; + ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + for (int i01 = 0; i01 < ne01; i01++) { + for (int i00 = 0; i00 < ne00; i00++) { + const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + + dst_ptr[id] = GGML_FP32_TO_FP16(*src0_ptr); + id++; + } + } + } + } + } else { + GGML_ASSERT(false); // TODO: implement + } + } +} + +static void ggml_compute_forward_dup( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F16: + { + ggml_compute_forward_dup_f16(params, src0, dst); + } break; + case 
GGML_TYPE_F32: + { + ggml_compute_forward_dup_f32(params, src0, dst); + } break; + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_COUNT: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_add + +static void ggml_compute_forward_add_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int ith = params->ith; + const int nth = params->nth; + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + const size_t nb00 = src0->nb[0]; + const size_t nb01 = src0->nb[1]; + + const size_t nb10 = src1->nb[0]; + const size_t nb11 = src1->nb[1]; + + const size_t nb0 = dst->nb[0]; + const size_t nb1 = dst->nb[1]; + + GGML_ASSERT( nb0 == sizeof(float)); + GGML_ASSERT(nb00 == sizeof(float)); + + if (nb10 == sizeof(float)) { + const int j0 = (n/nth)*ith; + const int j1 = ith == nth - 1 ? n : (n/nth)*(ith + 1); + + for (int j = j0; j < j1; j++) { + ggml_vec_add_f32(nc, + (float *) ((char *) dst->data + j*nb1), + (float *) ((char *) src0->data + j*nb01), + (float *) ((char *) src1->data + j*nb11)); + } + } else { + // src1 is not contiguous + for (int j = ith; j < n; j += nth) { + float * dst_ptr = (float *) ((char *) dst->data + j*nb1); + float * src0_ptr = (float *) ((char *) src0->data + j*nb01); + for (int i = 0; i < nc; i++) { + float * src1_ptr = (float *) ((char *) src1->data + j*nb11 + i*nb10); + + dst_ptr[i] = src0_ptr[i] + *src1_ptr; + } + } + } +} + +static void ggml_compute_forward_add( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_add_f32(params, src0, src1, dst); + } break; + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_F16: + case GGML_TYPE_COUNT: + { + assert(false); + } break; + } +} + +// ggml_compute_forward_sub + +static void ggml_compute_forward_sub_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + assert(params->ith == 0); + assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + assert( dst->nb[0] == sizeof(float)); + assert(src0->nb[0] == sizeof(float)); + assert(src1->nb[0] == sizeof(float)); + + for (int i = 0; i < n; i++) { + ggml_vec_sub_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1])), + (float *) ((char *) src1->data + i*(src1->nb[1]))); + } +} + +static void ggml_compute_forward_sub( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_sub_f32(params, src0, src1, dst); + } break; + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_F16: + case GGML_TYPE_COUNT: + { + assert(false); + } break; + } +} + +// ggml_compute_forward_mul + +static void ggml_compute_forward_mul_f32( + const struct 
ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + assert(params->ith == 0); + assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + assert( dst->nb[0] == sizeof(float)); + assert(src0->nb[0] == sizeof(float)); + assert(src1->nb[0] == sizeof(float)); + + for (int i = 0; i < n; i++) { + ggml_vec_mul_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1])), + (float *) ((char *) src1->data + i*(src1->nb[1]))); + } +} + +static void ggml_compute_forward_mul( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_mul_f32(params, src0, src1, dst); + } break; + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_F16: + case GGML_TYPE_COUNT: + { + assert(false); + } break; + } +} + +// ggml_compute_forward_div + +static void ggml_compute_forward_div_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + assert(params->ith == 0); + assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + assert( dst->nb[0] == sizeof(float)); + assert(src0->nb[0] == sizeof(float)); + assert(src1->nb[0] == sizeof(float)); + + for (int i = 0; i < n; i++) { + ggml_vec_div_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1])), + (float *) ((char *) src1->data + i*(src1->nb[1]))); + } +} + +static void ggml_compute_forward_div( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_div_f32(params, src0, src1, dst); + } break; + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_F16: + case GGML_TYPE_COUNT: + { + assert(false); + } break; + } +} + +// ggml_compute_forward_sqr + +static void ggml_compute_forward_sqr_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + assert(params->ith == 0); + assert(ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + assert( dst->nb[0] == sizeof(float)); + assert(src0->nb[0] == sizeof(float)); + + for (int i = 0; i < n; i++) { + ggml_vec_sqr_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + +static void ggml_compute_forward_sqr( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_sqr_f32(params, src0, dst); + } break; + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_F16: + case GGML_TYPE_COUNT: + { + assert(false); + } break; + 
} +} + +// ggml_compute_forward_sqrt + +static void ggml_compute_forward_sqrt_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + assert(params->ith == 0); + assert(ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + assert( dst->nb[0] == sizeof(float)); + assert(src0->nb[0] == sizeof(float)); + + for (int i = 0; i < n; i++) { + ggml_vec_sqrt_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + +static void ggml_compute_forward_sqrt( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_sqrt_f32(params, src0, dst); + } break; + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_F16: + case GGML_TYPE_COUNT: + { + assert(false); + } break; + } +} + +// ggml_compute_forward_sum + +static void ggml_compute_forward_sum_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + assert(params->ith == 0); + assert(ggml_is_scalar(dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + assert(ggml_is_scalar(dst)); + assert(src0->nb[0] == sizeof(float)); + + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; + const int ne03 = src0->ne[3]; + + const size_t nb01 = src0->nb[1]; + const size_t nb02 = src0->nb[2]; + const size_t nb03 = src0->nb[3]; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + for (int i01 = 0; i01 < ne01; i01++) { + ggml_vec_sum_f32(ne00, + (float *) (dst->data), + (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03)); + } + } + } +} + +static void ggml_compute_forward_sum( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_sum_f32(params, src0, dst); + } break; + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_F16: + case GGML_TYPE_COUNT: + { + assert(false); + } break; + } +} + +// ggml_compute_forward_mean + +static void ggml_compute_forward_mean_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + assert(params->ith == 0); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + assert(src0->nb[0] == sizeof(float)); + + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; + const int ne03 = src0->ne[3]; + + const size_t nb01 = src0->nb[1]; + const size_t nb02 = src0->nb[2]; + const size_t nb03 = src0->nb[3]; + + const int ne0 = dst->ne[0]; + const int ne1 = dst->ne[1]; + const int ne2 = dst->ne[2]; + const int ne3 = dst->ne[3]; + + assert(ne0 == 1); + assert(ne1 == ne01); + assert(ne2 == ne02); + assert(ne3 == ne03); + + UNUSED(ne0); + UNUSED(ne1); + UNUSED(ne2); + UNUSED(ne3); + + const size_t nb1 = dst->nb[1]; + const size_t nb2 = dst->nb[2]; + const size_t nb3 = dst->nb[3]; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + for (int i01 = 0; i01 < ne01; i01++) { + ggml_vec_sum_f32(ne00, + (float *) ((char *) 
dst->data + i01*nb1 + i02*nb2 + i03*nb3), + (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03)); + + *(float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3) /= (float) ne00; + } + } + } +} + +static void ggml_compute_forward_mean( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_mean_f32(params, src0, dst); + } break; + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_F16: + case GGML_TYPE_COUNT: + { + assert(false); + } break; + } +} + +// ggml_compute_forward_repeat + +static void ggml_compute_forward_repeat_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + assert(params->ith == 0); + assert(ggml_can_repeat(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + // TODO: implement support for rank > 2 tensors + assert(src0->ne[2] == 1); + assert(src0->ne[3] == 1); + assert( dst->ne[2] == 1); + assert( dst->ne[3] == 1); + + const int nc = dst->ne[0]; + const int nr = dst->ne[1]; + const int nc0 = src0->ne[0]; + const int nr0 = src0->ne[1]; + const int ncr = nc/nc0; // guaranteed to be an integer due to the check in ggml_can_repeat + const int nrr = nr/nr0; // guaranteed to be an integer due to the check in ggml_can_repeat + + // TODO: support for transposed / permuted tensors + assert( dst->nb[0] == sizeof(float)); + assert(src0->nb[0] == sizeof(float)); + + // TODO: maybe this is not optimal? + for (int i = 0; i < nrr; i++) { + for (int j = 0; j < ncr; j++) { + for (int k = 0; k < nr0; k++) { + ggml_vec_cpy_f32(nc0, + (float *) ((char *) dst->data + (i*nr0 + k)*( dst->nb[1]) + j*nc0*( dst->nb[0])), + (float *) ((char *) src0->data + ( k)*(src0->nb[1]))); + } + } + } +} + +static void ggml_compute_forward_repeat( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_repeat_f32(params, src0, dst); + } break; + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_F16: + case GGML_TYPE_COUNT: + { + assert(false); + } break; + } +} + +// ggml_compute_forward_abs + +static void ggml_compute_forward_abs_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + assert(params->ith == 0); + assert(ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + assert(dst->nb[0] == sizeof(float)); + assert(src0->nb[0] == sizeof(float)); + + for (int i = 0; i < n; i++) { + ggml_vec_abs_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + +static void ggml_compute_forward_abs( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_abs_f32(params, src0, dst); + } break; + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_F16: + case GGML_TYPE_COUNT: + { + assert(false); + } break; + } +} + +// ggml_compute_forward_sgn + +static void ggml_compute_forward_sgn_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * 
src0, + struct ggml_tensor * dst) { + assert(params->ith == 0); + assert(ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + assert(dst->nb[0] == sizeof(float)); + assert(src0->nb[0] == sizeof(float)); + + for (int i = 0; i < n; i++) { + ggml_vec_sgn_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + +static void ggml_compute_forward_sgn( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_sgn_f32(params, src0, dst); + } break; + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_F16: + case GGML_TYPE_COUNT: + { + assert(false); + } break; + } +} + +// ggml_compute_forward_neg + +static void ggml_compute_forward_neg_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + assert(params->ith == 0); + assert(ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + assert(dst->nb[0] == sizeof(float)); + assert(src0->nb[0] == sizeof(float)); + + for (int i = 0; i < n; i++) { + ggml_vec_neg_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + +static void ggml_compute_forward_neg( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_neg_f32(params, src0, dst); + } break; + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_F16: + case GGML_TYPE_COUNT: + { + assert(false); + } break; + } +} + +// ggml_compute_forward_step + +static void ggml_compute_forward_step_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + assert(params->ith == 0); + assert(ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + assert(dst->nb[0] == sizeof(float)); + assert(src0->nb[0] == sizeof(float)); + + for (int i = 0; i < n; i++) { + ggml_vec_step_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + +static void ggml_compute_forward_step( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_step_f32(params, src0, dst); + } break; + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_F16: + case GGML_TYPE_COUNT: + { + assert(false); + } break; + } +} + +// ggml_compute_forward_relu + +static void ggml_compute_forward_relu_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + assert(params->ith == 0); + assert(ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + assert(dst->nb[0] == sizeof(float)); 
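+    // rows are processed as dense float vectors by ggml_vec_relu_f32 below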
+ assert(src0->nb[0] == sizeof(float)); + + for (int i = 0; i < n; i++) { + ggml_vec_relu_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + +static void ggml_compute_forward_relu( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_relu_f32(params, src0, dst); + } break; + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_F16: + case GGML_TYPE_COUNT: + { + assert(false); + } break; + } +} + +// ggml_compute_forward_gelu + +static void ggml_compute_forward_gelu_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(dst)); + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = src0->ne[0]; + const int nr = ggml_nrows(src0); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + ggml_vec_gelu_f32(nc, + (float *) ((char *) dst->data + i1*( dst->nb[1])), + (float *) ((char *) src0->data + i1*(src0->nb[1]))); + +#ifndef NDEBUG + for (int k = 0; k < nc; k++) { + const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k]; + UNUSED(x); + assert(!isnan(x)); + assert(!isinf(x)); + } +#endif + } +} + +static void ggml_compute_forward_gelu( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_gelu_f32(params, src0, dst); + } break; + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_F16: + case GGML_TYPE_COUNT: + { + assert(false); + } break; + } +} + +// ggml_compute_forward_norm + +static void ggml_compute_forward_norm_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + GGML_ASSERT(src0->nb[0] == sizeof(float)); + + const int ith = params->ith; + const int nth = params->nth; + + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; + const int ne03 = src0->ne[3]; + + const size_t nb01 = src0->nb[1]; + const size_t nb02 = src0->nb[2]; + const size_t nb03 = src0->nb[3]; + + const size_t nb1 = dst->nb[1]; + const size_t nb2 = dst->nb[2]; + const size_t nb3 = dst->nb[3]; + + const ggml_float eps = 1e-5f; // TODO: make this a parameter + + // TODO: optimize + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + for (int i01 = ith; i01 < ne01; i01 += nth) { + const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); + + ggml_float mean = 0.0; + for (int i00 = 0; i00 < ne00; i00++) { + mean += x[i00]; + } + + mean /= ne00; + + float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3); + + ggml_float sum2 = 0.0; + for (int i00 = 0; i00 < ne00; i00++) { + ggml_float v = x[i00] - mean; + y[i00] = v; + sum2 += v*v; + } + + const float scale = 1.0/sqrt(sum2/ne00 + 
eps); + + ggml_vec_scale_f32(ne00, y, scale); + } + } + } +} + +static void ggml_compute_forward_norm( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_norm_f32(params, src0, dst); + } break; + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_F16: + case GGML_TYPE_COUNT: + { + assert(false); + } break; + } +} + +// ggml_compute_forward_mul_mat + +#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) +// helper function to determine if it is better to use BLAS or not +// for large matrices, BLAS is faster +static bool ggml_compute_forward_mul_mat_use_blas( + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + UNUSED(src0); + + const int ne10 = src1->ne[0]; + + const int ne0 = dst->ne[0]; + const int ne1 = dst->ne[1]; + + // TODO: find the optimal values for these + if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ( + (ne0 >= 32 && ne1 >= 32 && ne10 >= 32) + )) { + //printf("BLAS: %d %d %d\n", ne0, ne1, ne10); + return true; + } + + return false; +} +#endif + +static void ggml_compute_forward_mul_mat_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + int64_t t0 = ggml_perf_time_us(); + UNUSED(t0); + + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; + const int ne03 = src0->ne[3]; + + const int ne10 = src1->ne[0]; + const int ne11 = src1->ne[1]; + const int ne12 = src1->ne[2]; + const int ne13 = src1->ne[3]; + + const int ne0 = dst->ne[0]; + const int ne1 = dst->ne[1]; + const int ne2 = dst->ne[2]; + const int ne3 = dst->ne[3]; + const int ne = ne0*ne1*ne2*ne3; + + const int nb00 = src0->nb[0]; + const int nb01 = src0->nb[1]; + const int nb02 = src0->nb[2]; + const int nb03 = src0->nb[3]; + + const int nb10 = src1->nb[0]; + const int nb11 = src1->nb[1]; + const int nb12 = src1->nb[2]; + const int nb13 = src1->nb[3]; + + const int nb0 = dst->nb[0]; + const int nb1 = dst->nb[1]; + const int nb2 = dst->nb[2]; + const int nb3 = dst->nb[3]; + + const int ith = params->ith; + const int nth = params->nth; + + assert(ne02 == ne12); + assert(ne03 == ne13); + assert(ne2 == ne12); + assert(ne3 == ne13); + + // TODO: we don't support permuted src0 + assert(nb00 == sizeof(float) || nb01 == sizeof(float)); + + // dst cannot be transposed or permuted + assert(nb0 == sizeof(float)); + assert(nb0 <= nb1); + assert(nb1 <= nb2); + assert(nb2 <= nb3); + + assert(ne0 == ne01); + assert(ne1 == ne11); + assert(ne2 == ne02); + assert(ne3 == ne03); + + // nb01 >= nb00 - src0 is not transposed + // compute by src0 rows + // + // nb00 < nb01 - src0 is transposed + // compute by src0 columns + +#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) + if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) { + GGML_ASSERT(nb10 == sizeof(float)); + + if (params->ith != 0) { + return; + } + + if (params->type == GGML_TASK_INIT) { + return; + } + + if (params->type == GGML_TASK_FINALIZE) { + return; + } + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + const float * x = (float *) (src0->data); + const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13); + + float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3); + + // zT = y * xT + { + cblas_sgemm(CblasRowMajor, CblasNoTrans, 
CblasTrans, + ne11, ne01, ne10, + 1.0f, y, ne10, + x, ne10, + 0.0f, d, ne01); + } + } + } + + //printf("CBLAS F32 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3); + + return; + } +#endif + + if (params->type == GGML_TASK_INIT) { + if (nb01 >= nb00) { + return; + } + + // TODO: fix this memset (wsize is overestimated) + memset(params->wdata, 0, params->wsize); + return; + } + + if (params->type == GGML_TASK_FINALIZE) { + if (nb01 >= nb00) { + return; + } + + // TODO: fix this memset (wsize is overestimated) + //assert(params->wsize == (ggml_nbytes(dst) + CACHE_LINE_SIZE)*nth); + + float * const wdata = params->wdata; + + // cols per thread + const int dc = (ne + nth - 1)/nth; + + // col range for this thread + const int ic0 = dc*ith; + const int ic1 = MIN(ic0 + dc, ne); + + ggml_vec_cpy_f32(ic1 - ic0, (float *) dst->data + ic0, wdata + ic0); + + for (int k = 1; k < nth; k++) { + ggml_vec_acc_f32(ic1 - ic0, (float *) dst->data + ic0, wdata + (ne + CACHE_LINE_SIZE_F32)*k + ic0); + } + + return; + } + + if (nb01 >= nb00) { + // TODO: do not support transposed src1 + assert(nb10 == sizeof(float)); + + // parallelize by src0 rows using ggml_vec_dot_f32 + + // total rows in src0 + const int nr = ne01*ne02*ne03; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int ir = ir0; ir < ir1; ++ir) { + // src0 indices + const int i03 = ir/(ne02*ne01); + const int i02 = (ir - i03*ne02*ne01)/ne01; + const int i01 = (ir - i03*ne02*ne01 - i02*ne01); + + for (int ic = 0; ic < ne11; ++ic) { + // src1 indices + const int i13 = i03; + const int i12 = i02; + const int i11 = ic; + + // dst indices + const int i0 = i01; + const int i1 = i11; + const int i2 = i02; + const int i3 = i03; + + ggml_vec_dot_f32(ne00, + (float *) ((char *) dst->data + (i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), + (float *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03)), + (float *) ((char *) src1->data + (i11*nb11 + i12*nb12 + i13*nb13))); + } + } + } else { + // parallelize by src1 columns using ggml_vec_mad_f32 + // each thread has its own work data + // during FINALIZE we accumulate all work data into dst + + // total columns in src1 + const int nc = ne10; + + // columns per thread + const int dc = (nc + nth - 1)/nth; + + // column range for this thread + const int ic0 = dc*ith; + const int ic1 = MIN(ic0 + dc, nc); + + // work data for thread + const int wo = (ne + CACHE_LINE_SIZE_F32)*ith; + float * const wdata = params->wdata; + + for (int i13 = 0; i13 < ne13; ++i13) { + for (int i12 = 0; i12 < ne12; ++i12) { + for (int i11 = 0; i11 < ne11; ++i11) { + for (int ic = ic0; ic < ic1; ++ic) { + // src1 indices + const int i10 = ic; + + // src0 indices + const int i03 = i13; + const int i02 = i12; + const int i00 = ic; + + // dst indices + const int i1 = i11; + const int i2 = i12; + const int i3 = i13; + + assert(sizeof(float)*(wo + i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0 + ne01) <= params->wsize); + + ggml_vec_mad_f32(ne01, + (float *) (wdata + wo + i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0), + (float *) ((char *) src0->data + (i00*nb00 + i02*nb02 + i03*nb03)), + *(float *) ((char *) src1->data + (i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13))); + } + } + } + } + } + + //int64_t t1 = ggml_perf_time_us(); + //static int64_t acc = 0; + //acc += t1 - t0; + //if (t1 - t0 > 10) { + // printf("\n"); + // printf("ne00 = %5d, ne01 = %5d, ne02 = %5d, ne03 = %5d\n", ne00, ne01, ne02, ne03); + // 
printf("nb00 = %5d, nb01 = %5d, nb02 = %5d, nb03 = %5d\n", nb00, nb01, nb02, nb03); + // printf("ne10 = %5d, ne11 = %5d, ne12 = %5d, ne13 = %5d\n", ne10, ne11, ne12, ne13); + // printf("nb10 = %5d, nb11 = %5d, nb12 = %5d, nb13 = %5d\n", nb10, nb11, nb12, nb13); + + // printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX task %d/%d: %d us, acc = %d\n", ith, nth, (int) (t1 - t0), (int) acc); + //} +} + +static void ggml_compute_forward_mul_mat_f16_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + int64_t t0 = ggml_perf_time_us(); + UNUSED(t0); + + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; + const int ne03 = src0->ne[3]; + + const int ne10 = src1->ne[0]; + const int ne11 = src1->ne[1]; + const int ne12 = src1->ne[2]; + const int ne13 = src1->ne[3]; + + const int ne0 = dst->ne[0]; + const int ne1 = dst->ne[1]; + const int ne2 = dst->ne[2]; + const int ne3 = dst->ne[3]; + const int ne = ne0*ne1*ne2*ne3; + + const int nb00 = src0->nb[0]; + const int nb01 = src0->nb[1]; + const int nb02 = src0->nb[2]; + const int nb03 = src0->nb[3]; + + const int nb10 = src1->nb[0]; + const int nb11 = src1->nb[1]; + const int nb12 = src1->nb[2]; + const int nb13 = src1->nb[3]; + + const int nb0 = dst->nb[0]; + const int nb1 = dst->nb[1]; + const int nb2 = dst->nb[2]; + const int nb3 = dst->nb[3]; + + const int ith = params->ith; + const int nth = params->nth; + + GGML_ASSERT(ne02 == ne12); + GGML_ASSERT(ne03 == ne13); + GGML_ASSERT(ne2 == ne12); + GGML_ASSERT(ne3 == ne13); + + // TODO: we don't support permuted src0 + GGML_ASSERT(nb00 == sizeof(ggml_fp16_t) || nb01 == sizeof(ggml_fp16_t)); + + // dst cannot be transposed or permuted + GGML_ASSERT(nb0 == sizeof(float)); + GGML_ASSERT(nb0 <= nb1); + GGML_ASSERT(nb1 <= nb2); + GGML_ASSERT(nb2 <= nb3); + + GGML_ASSERT(ne0 == ne01); + GGML_ASSERT(ne1 == ne11); + GGML_ASSERT(ne2 == ne02); + GGML_ASSERT(ne3 == ne03); + + // nb01 >= nb00 - src0 is not transposed + // compute by src0 rows + // + // nb00 < nb01 - src0 is transposed + // compute by src0 columns + +#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) + if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) { + GGML_ASSERT(nb10 == sizeof(float)); + + if (params->ith != 0) { + return; + } + + if (params->type == GGML_TASK_INIT) { + return; + } + + if (params->type == GGML_TASK_FINALIZE) { + return; + } + + float * const wdata = params->wdata; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + { + int id = 0; + for (int i01 = 0; i01 < ne01; ++i01) { + for (int i00 = 0; i00 < ne00; ++i00) { + wdata[id++] = GGML_FP16_TO_FP32(*(ggml_fp16_t *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00)); + } + } + } + + const float * x = wdata; + const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13); + + // float * z = wdata + ne00*ne01; + + // z = x * yT + //{ + // cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, + // ne01, ne11, ne00, + // 1.0f, x, ne00, + // y, ne00, + // 0.0f, z, ne11); + //} + + float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3); + + // transpose z + //for (int j = 0; j < ne11; ++j) { + // for (int i = 0; i < ne01; ++i) { + // d[j*ne01 + i] = z[i*ne11 + j]; + // } + //} + + { +#if 1 + // zT = y * xT + cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, + ne11, ne01, ne10, + 1.0f, y, ne00, + x, ne00, + 0.0f, d, ne01); +#else + // zT = (xT * y)T + 
cblas_sgemm(CblasColMajor, CblasTrans, CblasNoTrans, + ne01, ne11, ne10, + 1.0f, x, ne00, + y, ne00, + 0.0f, d, ne01); +#endif + } + } + } + + //printf("CBLAS = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3); + + return; + } +#endif + + if (params->type == GGML_TASK_INIT) { + if (nb01 >= nb00) { + ggml_fp16_t * const wdata = params->wdata; + + int id = 0; + for (int i13 = 0; i13 < ne13; ++i13) { + for (int i12 = 0; i12 < ne12; ++i12) { + for (int i11 = 0; i11 < ne11; ++i11) { + for (int i10 = 0; i10 < ne10; ++i10) { + wdata[id++] = GGML_FP32_TO_FP16(*(float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10)); + } + } + } + } + + GGML_ASSERT(id*sizeof(ggml_fp16_t) <= params->wsize); + + return; + } + + // TODO: fix this memset (wsize is overestimated) + memset(params->wdata, 0, params->wsize); + return; + } + + if (params->type == GGML_TASK_FINALIZE) { + if (nb01 >= nb00) { + return; + } + + // TODO: fix this memset (wsize is overestimated) + //assert(params->wsize == (ggml_nbytes(dst) + CACHE_LINE_SIZE)*nth); + + ggml_fp16_t * const wdata = params->wdata; + + // cols per thread + const int dc = (ne + nth - 1)/nth; + + // col range for this thread + const int ic0 = dc*ith; + const int ic1 = MIN(ic0 + dc, ne); + + for (int i = ic0; i < ic1; ++i) { + ((float *) dst->data)[i] = GGML_FP16_TO_FP32(wdata[i]); + } + + for (int k = 1; k < nth; k++) { + for (int i = ic0; i < ic1; ++i) { + ((float *) dst->data)[i] += GGML_FP16_TO_FP32(wdata[(ne + CACHE_LINE_SIZE_F32)*k + i]); + } + } + + return; + } + + if (nb01 >= nb00) { + // fp16 -> half the size, so divide by 2 + // TODO: do not support transposed src1 + assert(nb10/2 == sizeof(ggml_fp16_t)); + + // parallelize by src0 rows using ggml_vec_dot_f16 + + // total rows in src0 + const int nr = ne01*ne02*ne03; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + ggml_fp16_t * wdata = params->wdata; + + for (int ir = ir0; ir < ir1; ++ir) { + // src0 indices + const int i03 = ir/(ne02*ne01); + const int i02 = (ir - i03*ne02*ne01)/ne01; + const int i01 = (ir - i03*ne02*ne01 - i02*ne01); + + const int i13 = i03; + const int i12 = i02; + + const int i0 = i01; + const int i2 = i02; + const int i3 = i03; + + ggml_fp16_t * src0_row = (ggml_fp16_t *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03)); + ggml_fp16_t * src1_col = wdata + ( 0 + i12*ne11 + i13*ne12*ne11)*ne00; + + float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3)); + + assert(ne00 % 32 == 0); + + for (int ic = 0; ic < ne11; ++ic) { + ggml_vec_dot_f16(ne00, &dst_col[ic*ne0], src0_row, src1_col + ic*ne00); + } + } + } else { + // parallelize by src1 columns using ggml_vec_mad_f16 + // each thread has its own work data + // during FINALIZE we accumulate all work data into dst + + // total columns in src1 + const int nc = ne10; + + // columns per thread + const int dc = (nc + nth - 1)/nth; + + // column range for this thread + const int ic0 = dc*ith; + const int ic1 = MIN(ic0 + dc, nc); + + // work data for thread + const int wo = (ne + CACHE_LINE_SIZE_F32)*ith; + ggml_fp16_t * const wdata = params->wdata; + + for (int i13 = 0; i13 < ne13; ++i13) { + for (int i12 = 0; i12 < ne12; ++i12) { + for (int i11 = 0; i11 < ne11; ++i11) { + // dst indices + const int i1 = i11; + const int i2 = i12; + const int i3 = i13; + + ggml_fp16_t * dst_row = wdata + wo + i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0; + + for (int ic 
= ic0; ic < ic1; ++ic) { + // src1 indices + const int i10 = ic; + + // src0 indices + const int i03 = i13; + const int i02 = i12; + const int i00 = ic; + + assert(sizeof(ggml_fp16_t)*(wo + i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0 + ne01) <= params->wsize); + + ggml_fp16_t * src0_col = (ggml_fp16_t *) ((char *) src0->data + (i00*nb00 + i02*nb02 + i03*nb03)); + float src1_val = * (float *) ((char *) src1->data + (i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13)); + + ggml_vec_mad_f16(ne01, dst_row, src0_col, src1_val); + } + } + } + } + } + + //int64_t t1 = ggml_time_us(); + //static int64_t acc = 0; + //acc += t1 - t0; + //if (t1 - t0 > 10) { + // printf("\n"); + // printf("ne00 = %5d, ne01 = %5d, ne02 = %5d, ne03 = %5d\n", ne00, ne01, ne02, ne03); + // printf("nb00 = %5d, nb01 = %5d, nb02 = %5d, nb03 = %5d\n", nb00, nb01, nb02, nb03); + // printf("ne10 = %5d, ne11 = %5d, ne12 = %5d, ne13 = %5d\n", ne10, ne11, ne12, ne13); + + // printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX task %d/%d: %d us, acc = %d\n", ith, nth, (int) (t1 - t0), (int) acc); + //} +} + +static void ggml_compute_forward_mul_mat( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F16: + { + ggml_compute_forward_mul_mat_f16_f32(params, src0, src1, dst); + } break; + case GGML_TYPE_F32: + { + ggml_compute_forward_mul_mat_f32(params, src0, src1, dst); + } break; + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_COUNT: + { + assert(false); + } break; + } +} + +// ggml_compute_forward_scale + +static void ggml_compute_forward_scale_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(dst)); + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + GGML_ASSERT(ggml_is_scalar(src1)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + // scale factor + const float v = *(float *) src1->data; + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = src0->ne[0]; + const int nr = ggml_nrows(src0); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + ggml_vec_scale_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), v); + } +} + +static void ggml_compute_forward_scale( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_scale_f32(params, src0, src1, dst); + } break; + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_F16: + case GGML_TYPE_COUNT: + { + assert(false); + } break; + } +} + +// ggml_compute_forward_cpy + +static void ggml_compute_forward_cpy( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + ggml_compute_forward_dup(params, src0, dst); +} + +// ggml_compute_forward_reshape + +static void ggml_compute_forward_reshape( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + // NOP + UNUSED(params); + UNUSED(src0); + UNUSED(dst); +} + +// ggml_compute_forward_view + 
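+// (like reshape above and permute/transpose below, view is a NOP at compute
+// time: these ops only rewrite the ne[]/nb[] metadata when the graph is built,
+// so the resulting tensor shares the data of its source and nothing is left to
+// do here)
+//
+// rough usage sketch, for illustration only (ggml_view_1d is declared in the
+// ggml.h added by this patch):
+//
+//   struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 16);
+//   struct ggml_tensor * v = ggml_view_1d(ctx, a, 8, 0); // v->data == a->data
+//   // evaluating v in a graph reaches this function, which simply returns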
+static void ggml_compute_forward_view( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0) { + // NOP + UNUSED(params); + UNUSED(src0); +} + +// ggml_compute_forward_permute + +static void ggml_compute_forward_permute( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0) { + // NOP + UNUSED(params); + UNUSED(src0); +} + +// ggml_compute_forward_transpose + +static void ggml_compute_forward_transpose( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0) { + // NOP + UNUSED(params); + UNUSED(src0); +} + +// ggml_compute_forward_get_rows + +static void ggml_compute_forward_get_rows_f16( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + assert(params->ith == 0); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int nc = src0->ne[0]; + const int nr = ggml_nelements(src1); + + assert( dst->ne[0] == nc); + assert( dst->ne[1] == nr); + assert(src0->nb[0] == sizeof(ggml_fp16_t)); + + for (int i = 0; i < nr; ++i) { + const int r = ((int32_t *) src1->data)[i]; + + for (int j = 0; j < nc; ++j) { + ggml_fp16_t v = ((ggml_fp16_t *) ((char *) src0->data + r*src0->nb[1]))[j]; + ((float *) ((char *) dst->data + i*dst->nb[1]))[j] = GGML_FP16_TO_FP32(v); + } + } +} + +static void ggml_compute_forward_get_rows_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + assert(params->ith == 0); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int nc = src0->ne[0]; + const int nr = ggml_nelements(src1); + + assert( dst->ne[0] == nc); + assert( dst->ne[1] == nr); + assert(src0->nb[0] == sizeof(float)); + + for (int i = 0; i < nr; ++i) { + const int r = ((int32_t *) src1->data)[i]; + + ggml_vec_cpy_f32(nc, + (float *) ((char *) dst->data + i*dst->nb[1]), + (float *) ((char *) src0->data + r*src0->nb[1])); + } +} + +static void ggml_compute_forward_get_rows( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F16: + { + ggml_compute_forward_get_rows_f16(params, src0, src1, dst); + } break; + case GGML_TYPE_F32: + { + ggml_compute_forward_get_rows_f32(params, src0, src1, dst); + } break; + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_COUNT: + { + assert(false); + } break; + } +} + +// ggml_compute_forward_diag_mask_inf + +static void ggml_compute_forward_diag_mask_inf_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + assert(params->ith == 0); + assert(src1->type == GGML_TYPE_I32); + assert(ggml_nelements(src1) == 1); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int n_past = ((int32_t *) src1->data)[0]; + + // TODO: handle transposed/permuted matrices + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + const int nr = src0->ne[1]; + const int nz = n/nr; + + assert( dst->nb[0] == sizeof(float)); + assert(src0->nb[0] == sizeof(float)); + + for (int k = 0; k < nz; k++) { + for (int j = 0; j < nr; j++) { + for (int i = n_past; i < nc; i++) { + if (i > n_past + j) { + *(float *)((char *) dst->data + 
k*dst->nb[2] + j*dst->nb[1] + i*dst->nb[0]) = -INFINITY; + } + } + } + } +} + +static void ggml_compute_forward_diag_mask_inf( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_diag_mask_inf_f32(params, src0, src1, dst); + } break; + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_F16: + case GGML_TYPE_COUNT: + { + assert(false); + } break; + } +} + +// ggml_compute_forward_soft_max + +static void ggml_compute_forward_soft_max_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(dst)); + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + // TODO: handle transposed/permuted matrices + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = src0->ne[0]; + const int nr = ggml_nrows(src0); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + float *p = (float *)((char *) dst->data + i1*dst->nb[1]); + +#ifndef NDEBUG + for (int i = 0; i < nc; ++i) { + assert(!isnan(p[i])); + } +#endif + + float max = -INFINITY; + ggml_vec_max_f32(nc, &max, p); + + ggml_float sum = 0.0; + + uint16_t scvt; + for (int i = 0; i < nc; i++) { + if (p[i] == -INFINITY) { + p[i] = 0.0f; + } else { + //const float val = (p[i] == -INFINITY) ? 0.0 : exp(p[i] - max); + ggml_fp16_t s = GGML_FP32_TO_FP16(p[i] - max); + memcpy(&scvt, &s, sizeof(scvt)); + const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt]); + sum += val; + p[i] = val; + } + } + + assert(sum > 0.0f); + + sum = 1.0/sum; + ggml_vec_scale_f32(nc, p, sum); + +#ifndef NDEBUG + for (int i = 0; i < nc; ++i) { + assert(!isnan(p[i])); + assert(!isinf(p[i])); + } +#endif + } +} + +static void ggml_compute_forward_soft_max( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_soft_max_f32(params, src0, dst); + } break; + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_F16: + case GGML_TYPE_COUNT: + { + assert(false); + } break; + } +} + +// ggml_compute_forward_rope + +static void ggml_compute_forward_rope_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + assert(params->ith == 0); + assert(src1->type == GGML_TYPE_I32); + assert(ggml_nelements(src1) == 3); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int n_past = ((int32_t *) src1->data)[0]; + const int n_dims = ((int32_t *) src1->data)[1]; + const int mode = ((int32_t *) src1->data)[2]; + + //const int ne0 = src0->ne[0]; + const int ne1 = src0->ne[1]; + const int ne2 = src0->ne[2]; + const int ne3 = src0->ne[3]; + + const int nb0 = src0->nb[0]; + const int nb1 = src0->nb[1]; + const int nb2 = src0->nb[2]; + const int nb3 = src0->nb[3]; + + //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3); + //printf("n_past = %d, ne2 = %d\n", n_past, ne2); + + assert(nb0 == sizeof(float)); + + // TODO: optimize 
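+    // rotary position embedding: each consecutive pair (x0, x1) in the first
+    // n_dims values of a row is rotated in-plane by the angle p*theta, where p
+    // is the absolute token position and theta = 10000^(-i0/n_dims):
+    //
+    //   dst[i0]   = x0*cos(p*theta) - x1*sin(p*theta)
+    //   dst[i0+1] = x0*sin(p*theta) + x1*cos(p*theta)
+    //
+    // (mode == 0 processes every row with the position shifted by n_past;
+    //  otherwise only the rows past n_past are touched, with p = i2)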
+ for (int i3 = 0; i3 < ne3; i3++) { + for (int i2 = (mode == 0 ? 0 : n_past); i2 < ne2; i2++) { + const int p = (mode == 0 ? n_past + i2 : i2); + for (int i1 = 0; i1 < ne1; i1++) { + for (int i0 = 0; i0 < n_dims; i0 += 2) { + const double theta = pow(10000.0, ((double)-i0)/n_dims); + + const double cos_theta = cos(p*theta); + const double sin_theta = sin(p*theta); + + const float * const src = (float *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + double x0 = src[0]; + double x1 = src[1]; + + dst_data[0] = x0*cos_theta - x1*sin_theta; + dst_data[1] = x0*sin_theta + x1*cos_theta; + } + } + } + } +} + +static void ggml_compute_forward_rope( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_rope_f32(params, src0, src1, dst); + } break; + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_F16: + case GGML_TYPE_COUNT: + { + assert(false); + } break; + } +} + +// ggml_compute_forward_conv_1d_1s + +static void ggml_compute_forward_conv_1d_1s_f16_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + int64_t t0 = ggml_perf_time_us(); + UNUSED(t0); + + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; + //const int ne03 = src0->ne[3]; + + const int ne10 = src1->ne[0]; + const int ne11 = src1->ne[1]; + //const int ne12 = src1->ne[2]; + //const int ne13 = src1->ne[3]; + + //const int ne0 = dst->ne[0]; + //const int ne1 = dst->ne[1]; + //const int ne2 = dst->ne[2]; + //const int ne3 = dst->ne[3]; + //const int ne = ne0*ne1*ne2*ne3; + + const int nb00 = src0->nb[0]; + const int nb01 = src0->nb[1]; + const int nb02 = src0->nb[2]; + //const int nb03 = src0->nb[3]; + + const int nb10 = src1->nb[0]; + const int nb11 = src1->nb[1]; + //const int nb12 = src1->nb[2]; + //const int nb13 = src1->nb[3]; + + //const int nb0 = dst->nb[0]; + const int nb1 = dst->nb[1]; + //const int nb2 = dst->nb[2]; + //const int nb3 = dst->nb[3]; + + const int ith = params->ith; + const int nth = params->nth; + + const int nk = ne00; + const int nh = nk/2; + + const int ew0 = ggml_up32(ne01); + + GGML_ASSERT(ne00 % 2 == 1); // TODO: support even kernel sizes + GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nb10 == sizeof(float)); + + if (params->type == GGML_TASK_INIT) { + // TODO: fix this memset (wsize is overestimated) + memset(params->wdata, 0, params->wsize); + + // prepare kernel data (src0) + { + ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0; + + for (int i02 = 0; i02 < ne02; i02++) { + for (int i01 = 0; i01 < ne01; i01++) { + const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i02*nb02 + i01*nb01); + ggml_fp16_t * dst_data = wdata + i02*ew0*ne00; + for (int i00 = 0; i00 < ne00; i00++) { + dst_data[i00*ew0 + i01] = src[i00]; + } + } + } + } + + // prepare source data (src1) + { + ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + ne02*ew0*ne00; + + for (int i11 = 0; i11 < ne11; i11++) { + const float * const src = (float *)((char *) src1->data + i11*nb11); + ggml_fp16_t * dst_data = wdata; + for (int i10 
= 0; i10 < ne10; i10++) { + dst_data[(i10 + nh)*ew0 + i11] = GGML_FP32_TO_FP16(src[i10]); + } + } + } + + return; + } + + if (params->type == GGML_TASK_FINALIZE) { + return; + } + + // total rows in dst + const int nr = ne02; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + float * dst_data = (float *)((char *) dst->data + i1*nb1); + for (int i0 = 0; i0 < ne10; ++i0) { + dst_data[i0] = 0; + for (int k = -nh; k <= nh; k++) { + float v = 0.0f; + ggml_vec_dot_f16(ew0, &v, + (ggml_fp16_t *) params->wdata + i1*ew0*ne00 + (nh + k)*ew0, + (ggml_fp16_t *) params->wdata + ne02*ew0*ne00 + (i0 + nh + k)*ew0); + + dst_data[i0] += v; + } + } + } +} + +static void ggml_compute_forward_conv_1d_1s_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + int64_t t0 = ggml_perf_time_us(); + UNUSED(t0); + + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; + //const int ne03 = src0->ne[3]; + + const int ne10 = src1->ne[0]; + const int ne11 = src1->ne[1]; + //const int ne12 = src1->ne[2]; + //const int ne13 = src1->ne[3]; + + //const int ne0 = dst->ne[0]; + //const int ne1 = dst->ne[1]; + //const int ne2 = dst->ne[2]; + //const int ne3 = dst->ne[3]; + //const int ne = ne0*ne1*ne2*ne3; + + const int nb00 = src0->nb[0]; + const int nb01 = src0->nb[1]; + const int nb02 = src0->nb[2]; + //const int nb03 = src0->nb[3]; + + const int nb10 = src1->nb[0]; + const int nb11 = src1->nb[1]; + //const int nb12 = src1->nb[2]; + //const int nb13 = src1->nb[3]; + + //const int nb0 = dst->nb[0]; + const int nb1 = dst->nb[1]; + //const int nb2 = dst->nb[2]; + //const int nb3 = dst->nb[3]; + + const int ith = params->ith; + const int nth = params->nth; + + const int nk = ne00; + const int nh = nk/2; + + const int ew0 = ggml_up32(ne01); + + GGML_ASSERT(ne00 % 2 == 1); // TODO: support even kernel sizes + GGML_ASSERT(nb00 == sizeof(float)); + GGML_ASSERT(nb10 == sizeof(float)); + + if (params->type == GGML_TASK_INIT) { + // TODO: fix this memset (wsize is overestimated) + memset(params->wdata, 0, params->wsize); + + // prepare kernel data (src0) + { + float * const wdata = (float *) params->wdata + 0; + + for (int i02 = 0; i02 < ne02; i02++) { + for (int i01 = 0; i01 < ne01; i01++) { + const float * const src = (float *)((char *) src0->data + i02*nb02 + i01*nb01); + float * dst_data = wdata + i02*ew0*ne00; + for (int i00 = 0; i00 < ne00; i00++) { + dst_data[i00*ew0 + i01] = src[i00]; + } + } + } + } + + // prepare source data (src1) + { + float * const wdata = (float *) params->wdata + ne02*ew0*ne00; + + for (int i11 = 0; i11 < ne11; i11++) { + const float * const src = (float *)((char *) src1->data + i11*nb11); + float * dst_data = wdata; + for (int i10 = 0; i10 < ne10; i10++) { + dst_data[(i10 + nh)*ew0 + i11] = src[i10]; + } + } + } + + return; + } + + if (params->type == GGML_TASK_FINALIZE) { + return; + } + + // total rows in dst + const int nr = ne02; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + float * dst_data = (float *)((char *) dst->data + i1*nb1); + 
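+        // dst row i1 is the stride-1 1D convolution of src1 with the i1-th set of
+        // kernel taps: for each output position i0, dot the channel-interleaved
+        // taps prepared in INIT against the zero-padded source window at i0 + k,
+        // k = -nh..nh, and accumulate the partial sums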
for (int i0 = 0; i0 < ne10; ++i0) { + dst_data[i0] = 0; + for (int k = -nh; k <= nh; k++) { + float v = 0.0f; + ggml_vec_dot_f32(ew0, &v, + (float *) params->wdata + i1*ew0*ne00 + (nh + k)*ew0, + (float *) params->wdata + ne02*ew0*ne00 + (i0 + nh + k)*ew0); + + dst_data[i0] += v; + } + } + } +} + +static void ggml_compute_forward_conv_1d_1s( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F16: + { + ggml_compute_forward_conv_1d_1s_f16_f32(params, src0, src1, dst); + } break; + case GGML_TYPE_F32: + { + ggml_compute_forward_conv_1d_1s_f32(params, src0, src1, dst); + } break; + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_COUNT: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_conv_1d_2s + +static void ggml_compute_forward_conv_1d_2s_f16_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + int64_t t0 = ggml_perf_time_us(); + UNUSED(t0); + + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; + //const int ne03 = src0->ne[3]; + + const int ne10 = src1->ne[0]; + const int ne11 = src1->ne[1]; + //const int ne12 = src1->ne[2]; + //const int ne13 = src1->ne[3]; + + //const int ne0 = dst->ne[0]; + //const int ne1 = dst->ne[1]; + //const int ne2 = dst->ne[2]; + //const int ne3 = dst->ne[3]; + //const int ne = ne0*ne1*ne2*ne3; + + const int nb00 = src0->nb[0]; + const int nb01 = src0->nb[1]; + const int nb02 = src0->nb[2]; + //const int nb03 = src0->nb[3]; + + const int nb10 = src1->nb[0]; + const int nb11 = src1->nb[1]; + //const int nb12 = src1->nb[2]; + //const int nb13 = src1->nb[3]; + + //const int nb0 = dst->nb[0]; + const int nb1 = dst->nb[1]; + //const int nb2 = dst->nb[2]; + //const int nb3 = dst->nb[3]; + + const int ith = params->ith; + const int nth = params->nth; + + const int nk = ne00; + const int nh = nk/2; + + const int ew0 = ggml_up32(ne01); + + GGML_ASSERT(ne00 % 2 == 1); // TODO: support even kernel sizes + GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nb10 == sizeof(float)); + + if (params->type == GGML_TASK_INIT) { + // TODO: fix this memset (wsize is overestimated) + memset(params->wdata, 0, params->wsize); + + // prepare kernel data (src0) + { + ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0; + + for (int i02 = 0; i02 < ne02; i02++) { + for (int i01 = 0; i01 < ne01; i01++) { + const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i02*nb02 + i01*nb01); + ggml_fp16_t * dst_data = wdata + i02*ew0*ne00; + for (int i00 = 0; i00 < ne00; i00++) { + dst_data[i00*ew0 + i01] = src[i00]; + } + } + } + } + + // prepare source data (src1) + { + ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + ne02*ew0*ne00; + + for (int i11 = 0; i11 < ne11; i11++) { + const float * const src = (float *)((char *) src1->data + i11*nb11); + ggml_fp16_t * dst_data = wdata; + for (int i10 = 0; i10 < ne10; i10++) { + dst_data[(i10 + nh)*ew0 + i11] = GGML_FP32_TO_FP16(src[i10]); + } + } + } + + return; + } + + if (params->type == GGML_TASK_FINALIZE) { + return; + } + + // total rows in dst + const int nr = ne02; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this 
thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + float * dst_data = (float *)((char *) dst->data + i1*nb1); + for (int i0 = 0; i0 < ne10; i0 += 2) { + dst_data[i0/2] = 0; + for (int k = -nh; k <= nh; k++) { + float v = 0.0f; + ggml_vec_dot_f16(ew0, &v, + (ggml_fp16_t *) params->wdata + i1*ew0*ne00 + (nh + k)*ew0, + (ggml_fp16_t *) params->wdata + ne02*ew0*ne00 + (i0 + nh + k)*ew0); + + dst_data[i0/2] += v; + } + } + } +} + +static void ggml_compute_forward_conv_1d_2s_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + int64_t t0 = ggml_perf_time_us(); + UNUSED(t0); + + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; + //const int ne03 = src0->ne[3]; + + const int ne10 = src1->ne[0]; + const int ne11 = src1->ne[1]; + //const int ne12 = src1->ne[2]; + //const int ne13 = src1->ne[3]; + + //const int ne0 = dst->ne[0]; + //const int ne1 = dst->ne[1]; + //const int ne2 = dst->ne[2]; + //const int ne3 = dst->ne[3]; + //const int ne = ne0*ne1*ne2*ne3; + + const int nb00 = src0->nb[0]; + const int nb01 = src0->nb[1]; + const int nb02 = src0->nb[2]; + //const int nb03 = src0->nb[3]; + + const int nb10 = src1->nb[0]; + const int nb11 = src1->nb[1]; + //const int nb12 = src1->nb[2]; + //const int nb13 = src1->nb[3]; + + //const int nb0 = dst->nb[0]; + const int nb1 = dst->nb[1]; + //const int nb2 = dst->nb[2]; + //const int nb3 = dst->nb[3]; + + const int ith = params->ith; + const int nth = params->nth; + + const int nk = ne00; + const int nh = nk/2; + + const int ew0 = ggml_up32(ne01); + + GGML_ASSERT(ne00 % 2 == 1); // TODO: support even kernel sizes + GGML_ASSERT(nb00 == sizeof(float)); + GGML_ASSERT(nb10 == sizeof(float)); + + if (params->type == GGML_TASK_INIT) { + // TODO: fix this memset (wsize is overestimated) + memset(params->wdata, 0, params->wsize); + + // prepare kernel data (src0) + { + float * const wdata = (float *) params->wdata + 0; + + for (int i02 = 0; i02 < ne02; i02++) { + for (int i01 = 0; i01 < ne01; i01++) { + const float * const src = (float *)((char *) src0->data + i02*nb02 + i01*nb01); + float * dst_data = wdata + i02*ew0*ne00; + for (int i00 = 0; i00 < ne00; i00++) { + dst_data[i00*ew0 + i01] = src[i00]; + } + } + } + } + + // prepare source data (src1) + { + float * const wdata = (float *) params->wdata + ne02*ew0*ne00; + + for (int i11 = 0; i11 < ne11; i11++) { + const float * const src = (float *)((char *) src1->data + i11*nb11); + float * dst_data = wdata; + for (int i10 = 0; i10 < ne10; i10++) { + dst_data[(i10 + nh)*ew0 + i11] = src[i10]; + } + } + } + + return; + } + + if (params->type == GGML_TASK_FINALIZE) { + return; + } + + // total rows in dst + const int nr = ne02; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + float * dst_data = (float *)((char *) dst->data + i1*nb1); + for (int i0 = 0; i0 < ne10; i0 += 2) { + dst_data[i0/2] = 0; + for (int k = -nh; k <= nh; k++) { + float v = 0.0f; + ggml_vec_dot_f32(ew0, &v, + (float *) params->wdata + i1*ew0*ne00 + (nh + k)*ew0, + (float *) params->wdata + ne02*ew0*ne00 + (i0 + nh + k)*ew0); + + dst_data[i0/2] += v; + } + } + } +} + 
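+// (conv_1d_2s mirrors conv_1d_1s above: the INIT preparation of the kernel and
+// source work buffers is identical, but the output loop advances i0 by 2 and
+// writes dst_data[i0/2], so the output row is half as long - a stride-2
+// convolution)
+//
+// hypothetical graph-building sketch, tensor names are illustrative only:
+//
+//   // mel: F32 input rows, w1/w2: F16 kernels (the F16/F32 paths above)
+//   cur = ggml_conv_1d_1s(ctx0, w1, mel);  // output length == input length
+//   cur = ggml_conv_1d_2s(ctx0, w2, cur);  // output length halved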
+static void ggml_compute_forward_conv_1d_2s( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F16: + { + ggml_compute_forward_conv_1d_2s_f16_f32(params, src0, src1, dst); + } break; + case GGML_TYPE_F32: + { + ggml_compute_forward_conv_1d_2s_f32(params, src0, src1, dst); + } break; + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_COUNT: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_flash_attn + +static void ggml_compute_forward_flash_attn_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * q, + const struct ggml_tensor * k, + const struct ggml_tensor * v, + const bool masked, + struct ggml_tensor * dst) { + int64_t t0 = ggml_perf_time_us(); + UNUSED(t0); + + const int neq0 = q->ne[0]; + const int neq1 = q->ne[1]; + const int neq2 = q->ne[2]; + const int neq3 = q->ne[3]; + + const int nek0 = k->ne[0]; + const int nek1 = k->ne[1]; + //const int nek2 = k->ne[2]; + //const int nek3 = k->ne[3]; + + //const int nev0 = v->ne[0]; + const int nev1 = v->ne[1]; + //const int nev2 = v->ne[2]; + //const int nev3 = v->ne[3]; + + const int ne0 = dst->ne[0]; + const int ne1 = dst->ne[1]; + //const int ne2 = dst->ne[2]; + //const int ne3 = dst->ne[3]; + + const int nbk0 = k->nb[0]; + const int nbk1 = k->nb[1]; + const int nbk2 = k->nb[2]; + const int nbk3 = k->nb[3]; + + const int nbq0 = q->nb[0]; + const int nbq1 = q->nb[1]; + const int nbq2 = q->nb[2]; + const int nbq3 = q->nb[3]; + + const int nbv0 = v->nb[0]; + const int nbv1 = v->nb[1]; + const int nbv2 = v->nb[2]; + const int nbv3 = v->nb[3]; + + const int nb0 = dst->nb[0]; + const int nb1 = dst->nb[1]; + const int nb2 = dst->nb[2]; + const int nb3 = dst->nb[3]; + + const int ith = params->ith; + const int nth = params->nth; + + const int D = neq0; + const int N = neq1; + const int P = nek1 - N; + const int M = P + N; + + const int Mup = ggml_up(M, GGML_SOFT_MAX_UNROLL); + + GGML_ASSERT(ne0 == D); + GGML_ASSERT(ne1 == N); + GGML_ASSERT(P >= 0); + + GGML_ASSERT(nbq0 == sizeof(float)); + GGML_ASSERT(nbk0 == sizeof(float)); + GGML_ASSERT(nbv0 == sizeof(float)); + + GGML_ASSERT(neq0 == D); + GGML_ASSERT(nek0 == D); + GGML_ASSERT(nev1 == D); + + GGML_ASSERT(neq1 == N); + GGML_ASSERT(nek1 == N + P); + GGML_ASSERT(nev1 == D); + + // dst cannot be transposed or permuted + GGML_ASSERT(nb0 == sizeof(float)); + GGML_ASSERT(nb0 <= nb1); + GGML_ASSERT(nb1 <= nb2); + GGML_ASSERT(nb2 <= nb3); + + if (params->type == GGML_TASK_INIT) { + return; + } + + if (params->type == GGML_TASK_FINALIZE) { + return; + } + + // parallelize by q rows using ggml_vec_dot_f32 + + // total rows in q + const int nr = neq1*neq2*neq3; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + const float scale = 1.0/sqrt((double) D); + + //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale); + + for (int ir = ir0; ir < ir1; ++ir) { + // q indices + const int iq3 = ir/(neq2*neq1); + const int iq2 = (ir - iq3*neq2*neq1)/neq1; + const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1); + + float * S = (float *) params->wdata + ith*(Mup + CACHE_LINE_SIZE_F32); + + for (int i = M; i < Mup; ++i) { + S[i] = -INFINITY; + } + + for (int ic = 0; ic < nek1; ++ic) { + // k indices + const int ik3 = iq3; + const int ik2 = iq2; + const int ik1 = ic; + + // S indices 
+ const int i1 = ik1; + + ggml_vec_dot_f32(neq0, + S + i1, + (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), + (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3))); + } + + // scale + ggml_vec_scale_f32(nek1, S, scale); + + if (masked) { + for (int i = P; i < M; i++) { + if (i > P + iq1) { + S[i] = -INFINITY; + } + } + } + + // softmax + { + float max = -INFINITY; + ggml_vec_max_f32(M, &max, S); + + float sum = 0.0f; + { +#ifdef GGML_SOFT_MAX_ACCELERATE + max = -max; + vDSP_vsadd(S, 1, &max, S, 1, Mup); + vvexpf(S, S, &Mup); + ggml_vec_sum_f32(Mup, &sum, S); +#else + uint16_t scvt[GGML_SOFT_MAX_UNROLL]; + ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 }; + + for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) { + float * SS = S + i; + + for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) { + if (SS[j] == -INFINITY) { + SS[j] = 0.0f; + } else { + ggml_fp16_t s = GGML_FP32_TO_FP16(SS[j] - max); + memcpy(&scvt[j], &s, sizeof(uint16_t)); + const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]); + sump[j] += val; + SS[j] = val; + } + } + } + + for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) { + sum += sump[i]; + } +#endif + } + + assert(sum > 0.0f); + + sum = 1.0/sum; + ggml_vec_scale_f32(M, S, sum); + +#ifndef NDEBUG + for (int i = 0; i < M; ++i) { + assert(!isnan(S[i])); + assert(!isinf(S[i])); + } +#endif + } + + for (int ic = 0; ic < nev1; ++ic) { + // dst indices + const int i1 = iq1; + const int i2 = iq2; + const int i3 = iq3; + + ggml_vec_dot_f32(nek1, + (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), + (float *) ((char *) v->data + ( ic*nbv1 + i2*nbv2 + i3*nbv3)), + S); + } + } +} + +static void ggml_compute_forward_flash_attn_f16( + const struct ggml_compute_params * params, + const struct ggml_tensor * q, + const struct ggml_tensor * k, + const struct ggml_tensor * v, + const bool masked, + struct ggml_tensor * dst) { + int64_t t0 = ggml_perf_time_us(); + UNUSED(t0); + + const int neq0 = q->ne[0]; + const int neq1 = q->ne[1]; + const int neq2 = q->ne[2]; + const int neq3 = q->ne[3]; + + const int nek0 = k->ne[0]; + const int nek1 = k->ne[1]; + //const int nek2 = k->ne[2]; + //const int nek3 = k->ne[3]; + + //const int nev0 = v->ne[0]; + const int nev1 = v->ne[1]; + //const int nev2 = v->ne[2]; + //const int nev3 = v->ne[3]; + + const int ne0 = dst->ne[0]; + const int ne1 = dst->ne[1]; + //const int ne2 = dst->ne[2]; + //const int ne3 = dst->ne[3]; + + const int nbk0 = k->nb[0]; + const int nbk1 = k->nb[1]; + const int nbk2 = k->nb[2]; + const int nbk3 = k->nb[3]; + + const int nbq0 = q->nb[0]; + const int nbq1 = q->nb[1]; + const int nbq2 = q->nb[2]; + const int nbq3 = q->nb[3]; + + const int nbv0 = v->nb[0]; + const int nbv1 = v->nb[1]; + const int nbv2 = v->nb[2]; + const int nbv3 = v->nb[3]; + + const int nb0 = dst->nb[0]; + const int nb1 = dst->nb[1]; + const int nb2 = dst->nb[2]; + const int nb3 = dst->nb[3]; + + const int ith = params->ith; + const int nth = params->nth; + + const int D = neq0; + const int N = neq1; + const int P = nek1 - N; + const int M = P + N; + + const int Mup = ggml_up(M, GGML_SOFT_MAX_UNROLL); + + GGML_ASSERT(ne0 == D); + GGML_ASSERT(ne1 == N); + GGML_ASSERT(P >= 0); + + GGML_ASSERT(nbq0 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nbk0 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nbv0 == sizeof(ggml_fp16_t)); + + GGML_ASSERT(neq0 == D); + GGML_ASSERT(nek0 == D); + GGML_ASSERT(nev1 == D); + + GGML_ASSERT(neq1 == N); + GGML_ASSERT(nek1 == N + P); + GGML_ASSERT(nev1 == D); + + // dst cannot be transposed or permuted + 
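+    // (layout requirements mirror the F32 kernel above, except that q, k and v
+    //  are expected in fp16 here; dst itself stays F32)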
GGML_ASSERT(nb0 == sizeof(float)); + GGML_ASSERT(nb0 <= nb1); + GGML_ASSERT(nb1 <= nb2); + GGML_ASSERT(nb2 <= nb3); + + if (params->type == GGML_TASK_INIT) { + return; + } + + if (params->type == GGML_TASK_FINALIZE) { + return; + } + + // parallelize by q rows using ggml_vec_dot_f32 + + // total rows in q + const int nr = neq1*neq2*neq3; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + const float scale = 1.0/sqrt((double) D); + + //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale); + + for (int ir = ir0; ir < ir1; ++ir) { + // q indices + const int iq3 = ir/(neq2*neq1); + const int iq2 = (ir - iq3*neq2*neq1)/neq1; + const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1); + + float * S = (float *) params->wdata + ith*(2*Mup + CACHE_LINE_SIZE_F32); + + for (int i = M; i < Mup; ++i) { + S[i] = -INFINITY; + } + + if (GGML_VEC_DOT_UNROLL > 2 || nek1 % GGML_VEC_DOT_UNROLL != 0) { + for (int ic = 0; ic < nek1; ++ic) { + // k indices + const int ik3 = iq3; + const int ik2 = iq2; + const int ik1 = ic; + + // S indices + const int i1 = ik1; + + ggml_vec_dot_f16(neq0, + S + i1, + (ggml_fp16_t *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), + (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3))); + } + } else { + for (int ic = 0; ic < nek1; ic += GGML_VEC_DOT_UNROLL) { + // k indices + const int ik3 = iq3; + const int ik2 = iq2; + const int ik1 = ic; + + // S indices + const int i1 = ik1; + + ggml_vec_dot_f16_unroll(neq0, nbk1, + S + i1, + ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), + (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3))); + } + } + + // scale + ggml_vec_scale_f32(nek1, S, scale); + + if (masked) { + for (int i = P; i < M; i++) { + if (i > P + iq1) { + S[i] = -INFINITY; + } + } + } + + // softmax + { + float max = -INFINITY; + ggml_vec_max_f32(M, &max, S); + + float sum = 0.0f; + { +#ifdef GGML_SOFT_MAX_ACCELERATE + max = -max; + vDSP_vsadd(S, 1, &max, S, 1, Mup); + vvexpf(S, S, &Mup); + ggml_vec_sum_f32(Mup, &sum, S); +#else + uint16_t scvt[GGML_SOFT_MAX_UNROLL]; + ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 }; + + for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) { + float * SS = S + i; + + for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) { + if (SS[j] == -INFINITY) { + SS[j] = 0.0f; + } else { + ggml_fp16_t s = GGML_FP32_TO_FP16(SS[j] - max); + memcpy(&scvt[j], &s, sizeof(uint16_t)); + const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]); + sump[j] += val; + SS[j] = val; + } + } + } + + for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) { + sum += sump[i]; + } +#endif + } + + assert(sum > 0.0f); + + sum = 1.0/sum; + ggml_vec_scale_f32(M, S, sum); + +#ifndef NDEBUG + for (int i = 0; i < M; ++i) { + assert(!isnan(S[i])); + assert(!isinf(S[i])); + } +#endif + } + + ggml_fp16_t * S16 = (ggml_fp16_t *) ((float *) params->wdata + ith*(2*Mup + CACHE_LINE_SIZE_F32) + Mup); + + for (int i = 0; i < M; i++) { + S16[i] = GGML_FP32_TO_FP16(S[i]); + } + + if (GGML_VEC_DOT_UNROLL == 1 || (nev1 % GGML_VEC_DOT_UNROLL != 0)) { + for (int ic = 0; ic < nev1; ++ic) { + // dst indices + const int i1 = iq1; + const int i2 = iq2; + const int i3 = iq3; + + ggml_vec_dot_f16(nek1, + (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), + (ggml_fp16_t *) ((char *) v->data + ( ic*nbv1 + i2*nbv2 + i3*nbv3)), + S16); + } + } else { + for (int ic = 0; ic < nev1; ic += GGML_VEC_DOT_UNROLL) { + // dst indices + 
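+                // (unrolled variant of the output step: each call dots
+                //  GGML_VEC_DOT_UNROLL consecutive rows of v, spaced nbv1 bytes
+                //  apart, against the fp16 scores S16)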
const int i1 = iq1; + const int i2 = iq2; + const int i3 = iq3; + + ggml_vec_dot_f16_unroll(nek1, nbv1, + (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), + ((char *) v->data + ( ic*nbv1 + i2*nbv2 + i3*nbv3)), + S16); + } + } + } +} + +static void ggml_compute_forward_flash_attn( + const struct ggml_compute_params * params, + const struct ggml_tensor * q, + const struct ggml_tensor * k, + const struct ggml_tensor * v, + const bool masked, + struct ggml_tensor * dst) { + switch (q->type) { + case GGML_TYPE_F16: + { + ggml_compute_forward_flash_attn_f16(params, q, k, v, masked, dst); + } break; + case GGML_TYPE_F32: + { + ggml_compute_forward_flash_attn_f32(params, q, k, v, masked, dst); + } break; + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_COUNT: + { + assert(false); + } break; + } +} + +// ggml_compute_forward_flash_ff + +static void ggml_compute_forward_flash_ff_f16( + const struct ggml_compute_params * params, + const struct ggml_tensor * a, // F16 + const struct ggml_tensor * b0, // F16 fc_w + const struct ggml_tensor * b1, // F32 fc_b + const struct ggml_tensor * c0, // F16 proj_w + const struct ggml_tensor * c1, // F32 proj_b + struct ggml_tensor * dst) { + int64_t t0 = ggml_perf_time_us(); + UNUSED(t0); + + const int nea0 = a->ne[0]; + const int nea1 = a->ne[1]; + const int nea2 = a->ne[2]; + const int nea3 = a->ne[3]; + + const int neb00 = b0->ne[0]; + const int neb01 = b0->ne[1]; + //const int neb02 = b0->ne[2]; + //const int neb03 = b0->ne[3]; + + const int neb10 = b1->ne[0]; + const int neb11 = b1->ne[1]; + //const int neb12 = b1->ne[2]; + //const int neb13 = b1->ne[3]; + + const int nec00 = c0->ne[0]; + const int nec01 = c0->ne[1]; + //const int nec02 = c0->ne[2]; + //const int nec03 = c0->ne[3]; + + const int nec10 = c1->ne[0]; + const int nec11 = c1->ne[1]; + //const int nec12 = c1->ne[2]; + //const int nec13 = c1->ne[3]; + + const int ne0 = dst->ne[0]; + const int ne1 = dst->ne[1]; + const int ne2 = dst->ne[2]; + //const int ne3 = dst->ne[3]; + + const int nba0 = a->nb[0]; + const int nba1 = a->nb[1]; + const int nba2 = a->nb[2]; + const int nba3 = a->nb[3]; + + const int nbb00 = b0->nb[0]; + const int nbb01 = b0->nb[1]; + const int nbb02 = b0->nb[2]; + const int nbb03 = b0->nb[3]; + + const int nbb10 = b1->nb[0]; + //const int nbb11 = b1->nb[1]; + //const int nbb12 = b1->nb[2]; + //const int nbb13 = b1->nb[3]; + + const int nbc00 = c0->nb[0]; + const int nbc01 = c0->nb[1]; + const int nbc02 = c0->nb[2]; + const int nbc03 = c0->nb[3]; + + const int nbc10 = c1->nb[0]; + //const int nbc11 = c1->nb[1]; + //const int nbc12 = c1->nb[2]; + //const int nbc13 = c1->nb[3]; + + const int nb0 = dst->nb[0]; + const int nb1 = dst->nb[1]; + const int nb2 = dst->nb[2]; + const int nb3 = dst->nb[3]; + + const int ith = params->ith; + const int nth = params->nth; + + const int D = nea0; + //const int N = nea1; + const int M = neb01; + + GGML_ASSERT(ne0 == nea0); + GGML_ASSERT(ne1 == nea1); + GGML_ASSERT(ne2 == nea2); + + GGML_ASSERT(nba0 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nbb00 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nbb10 == sizeof(float)); + GGML_ASSERT(nbc00 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nbc10 == sizeof(float)); + + GGML_ASSERT(neb00 == D); + GGML_ASSERT(neb01 == M); + GGML_ASSERT(neb10 == M); + GGML_ASSERT(neb11 == 1); + + GGML_ASSERT(nec00 == M); + GGML_ASSERT(nec01 == D); + GGML_ASSERT(nec10 == D); + GGML_ASSERT(nec11 == 1); + + // dst cannot be transposed or permuted + GGML_ASSERT(nb0 == sizeof(float)); + 
GGML_ASSERT(nb0 <= nb1); + GGML_ASSERT(nb1 <= nb2); + GGML_ASSERT(nb2 <= nb3); + + if (params->type == GGML_TASK_INIT) { + return; + } + + if (params->type == GGML_TASK_FINALIZE) { + return; + } + + // parallelize by a rows using ggml_vec_dot_f32 + + // total rows in a + const int nr = nea1*nea2*nea3; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int ir = ir0; ir < ir1; ++ir) { + // a indices + const int ia3 = ir/(nea2*nea1); + const int ia2 = (ir - ia3*nea2*nea1)/nea1; + const int ia1 = (ir - ia3*nea2*nea1 - ia2*nea1); + + float * S = (float *) params->wdata + ith*(2*M + CACHE_LINE_SIZE_F32); + + for (int ic = 0; ic < neb01; ++ic) { + // b0 indices + const int ib03 = ia3; + const int ib02 = ia2; + const int ib01 = ic; + + // S indices + const int i1 = ib01; + + ggml_vec_dot_f16(nea0, + S + i1, + (ggml_fp16_t *) ((char *) b0->data + (ib01*nbb01 + ib02*nbb02 + ib03*nbb03)), + (ggml_fp16_t *) ((char *) a->data + ( ia1*nba1 + ia2*nba2 + ia3*nba3))); + } + + ggml_vec_add_f32(neb01, S, S, (float *) b1->data); + //ggml_vec_gelu_f32(neb01, S, S); + + ggml_fp16_t * S16 = (ggml_fp16_t *) ((float *) params->wdata + ith*(2*M + CACHE_LINE_SIZE_F32) + M); + + for (int i = 0; i < M; i++) { + S16[i] = GGML_FP32_TO_FP16(S[i]); + } + + ggml_vec_gelu_f16(neb01, S16, S16); + + { + // dst indices + const int i1 = ia1; + const int i2 = ia2; + const int i3 = ia3; + + for (int ic = 0; ic < nec01; ++ic) { + + ggml_vec_dot_f16(neb01, + (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), + (ggml_fp16_t *) ((char *) c0->data + ( ic*nbc01 + i2*nbc02 + i3*nbc03)), + S16); + } + + ggml_vec_add_f32(nec01, + (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3)), + (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3)), + (float *) c1->data); + } + } +} + +static void ggml_compute_forward_flash_ff( + const struct ggml_compute_params * params, + const struct ggml_tensor * a, + const struct ggml_tensor * b0, + const struct ggml_tensor * b1, + const struct ggml_tensor * c0, + const struct ggml_tensor * c1, + struct ggml_tensor * dst) { + switch (b0->type) { + case GGML_TYPE_F16: + { + ggml_compute_forward_flash_ff_f16(params, a, b0, b1, c0, c1, dst); + } break; + case GGML_TYPE_F32: + { + GGML_ASSERT(false); // TODO + } break; + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_COUNT: + { + assert(false); + } break; + } +} + +///////////////////////////////// + +static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) { + assert(params); + + switch (tensor->op) { + case GGML_OP_DUP: + { + ggml_compute_forward_dup(params, tensor->src0, tensor); + } break; + case GGML_OP_ADD: + { + ggml_compute_forward_add(params, tensor->src0, tensor->src1, tensor); + } break; + case GGML_OP_SUB: + { + ggml_compute_forward_sub(params, tensor->src0, tensor->src1, tensor); + } break; + case GGML_OP_MUL: + { + ggml_compute_forward_mul(params, tensor->src0, tensor->src1, tensor); + } break; + case GGML_OP_DIV: + { + ggml_compute_forward_div(params, tensor->src0, tensor->src1, tensor); + } break; + case GGML_OP_SQR: + { + ggml_compute_forward_sqr(params, tensor->src0, tensor); + } break; + case GGML_OP_SQRT: + { + ggml_compute_forward_sqrt(params, tensor->src0, tensor); + } break; + case GGML_OP_SUM: + { + ggml_compute_forward_sum(params, tensor->src0, tensor); + } break; + case GGML_OP_MEAN: + { + ggml_compute_forward_mean(params, 
tensor->src0, tensor); + } break; + case GGML_OP_REPEAT: + { + ggml_compute_forward_repeat(params, tensor->src0, tensor); + } break; + case GGML_OP_ABS: + { + ggml_compute_forward_abs(params, tensor->src0, tensor); + } break; + case GGML_OP_SGN: + { + ggml_compute_forward_sgn(params, tensor->src0, tensor); + } break; + case GGML_OP_NEG: + { + ggml_compute_forward_neg(params, tensor->src0, tensor); + } break; + case GGML_OP_STEP: + { + ggml_compute_forward_step(params, tensor->src0, tensor); + } break; + case GGML_OP_RELU: + { + ggml_compute_forward_relu(params, tensor->src0, tensor); + } break; + case GGML_OP_GELU: + { + ggml_compute_forward_gelu(params, tensor->src0, tensor); + } break; + case GGML_OP_NORM: + { + ggml_compute_forward_norm(params, tensor->src0, tensor); + } break; + case GGML_OP_MUL_MAT: + { + ggml_compute_forward_mul_mat(params, tensor->src0, tensor->src1, tensor); + } break; + case GGML_OP_SCALE: + { + ggml_compute_forward_scale(params, tensor->src0, tensor->src1, tensor); + } break; + case GGML_OP_CPY: + { + ggml_compute_forward_cpy(params, tensor->src0, tensor); + } break; + case GGML_OP_RESHAPE: + { + ggml_compute_forward_reshape(params, tensor->src0, tensor); + } break; + case GGML_OP_VIEW: + { + ggml_compute_forward_view(params, tensor->src0); + } break; + case GGML_OP_PERMUTE: + { + ggml_compute_forward_permute(params, tensor->src0); + } break; + case GGML_OP_TRANSPOSE: + { + ggml_compute_forward_transpose(params, tensor->src0); + } break; + case GGML_OP_GET_ROWS: + { + ggml_compute_forward_get_rows(params, tensor->src0, tensor->src1, tensor); + } break; + case GGML_OP_DIAG_MASK_INF: + { + ggml_compute_forward_diag_mask_inf(params, tensor->src0, tensor->src1, tensor); + } break; + case GGML_OP_SOFT_MAX: + { + ggml_compute_forward_soft_max(params, tensor->src0, tensor); + } break; + case GGML_OP_ROPE: + { + ggml_compute_forward_rope(params, tensor->src0, tensor->src1, tensor); + } break; + case GGML_OP_CONV_1D_1S: + { + ggml_compute_forward_conv_1d_1s(params, tensor->src0, tensor->src1, tensor); + } break; + case GGML_OP_CONV_1D_2S: + { + ggml_compute_forward_conv_1d_2s(params, tensor->src0, tensor->src1, tensor); + } break; + case GGML_OP_FLASH_ATTN: + { + int32_t t = ggml_get_i32_1d(tensor->opt[1], 0); + GGML_ASSERT(t == 0 || t == 1); + bool masked = t != 0; + ggml_compute_forward_flash_attn(params, tensor->src0, tensor->src1, tensor->opt[0], masked, tensor); + } break; + case GGML_OP_FLASH_FF: + { + ggml_compute_forward_flash_ff(params, tensor->src0, tensor->src1, tensor->opt[0], tensor->opt[1], tensor->opt[2], tensor); + } break; + case GGML_OP_NONE: + { + // nop + } break; + case GGML_OP_COUNT: + { + GGML_ASSERT(false); + } break; + } +} + +//////////////////////////////////////////////////////////////////////////////// + +static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tensor, bool inplace) { + struct ggml_tensor * src0 = tensor->src0; + struct ggml_tensor * src1 = tensor->src1; + + switch (tensor->op) { + case GGML_OP_DUP: + { + if (src0->grad) { + src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace); + } + } break; + case GGML_OP_ADD: + { + if (src0->grad) { + src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace); + } + if (src1->grad) { + src1->grad = ggml_add_impl(ctx, src1->grad, tensor->grad, inplace); + } + } break; + case GGML_OP_SUB: + { + if (src0->grad) { + src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace); + } + if (src1->grad) { + src1->grad = ggml_sub_impl(ctx, 
src1->grad, tensor->grad, inplace); + } + } break; + case GGML_OP_MUL: + { + if (src0->grad) { + src0->grad = + ggml_add_impl(ctx, + src0->grad, + ggml_mul(ctx, src1, tensor->grad), + inplace); + } + if (src1->grad) { + src1->grad = + ggml_add_impl(ctx, + src1->grad, + ggml_mul(ctx, src0, tensor->grad), + inplace); + } + } break; + case GGML_OP_DIV: + { + if (src0->grad) { + src0->grad = + ggml_add_impl(ctx, + src0->grad, + ggml_div(ctx, tensor->grad, src1), + inplace); + } + if (src1->grad) { + src1->grad = + ggml_sub_impl(ctx, + src1->grad, + ggml_mul(ctx, + tensor->grad, + ggml_div(ctx, tensor, src1)), + inplace); + } + } break; + case GGML_OP_SQR: + { + if (src0->grad) { + src0->grad = + ggml_add_impl(ctx, + src0->grad, + ggml_mul(ctx, + ggml_mul(ctx, src0, tensor->grad), + ggml_repeat(ctx, ggml_new_f32(ctx, 2.0f), src0)), + inplace); + } + } break; + case GGML_OP_SQRT: + { + if (src0->grad) { + src0->grad = + ggml_add_impl(ctx, + src0->grad, + ggml_div(ctx, + ggml_repeat(ctx, ggml_new_f32(ctx, 0.5f), tensor), + tensor), + inplace); + } + } break; + case GGML_OP_SUM: + { + if (src0->grad) { + src0->grad = + ggml_add_impl(ctx, + src0->grad, + ggml_repeat(ctx, tensor->grad, src0->grad), + inplace); + } + } break; + case GGML_OP_MEAN: + { + assert(false); // TODO: implement + } break; + case GGML_OP_REPEAT: + { + if (src0->grad) { + src0->grad = + ggml_add_impl(ctx, + src0->grad, + ggml_sum(ctx, tensor->grad), + inplace); + } + } break; + case GGML_OP_ABS: + { + if (src0->grad) { + src0->grad = + ggml_add_impl(ctx, + src0->grad, + ggml_mul(ctx, + ggml_sgn(ctx, src0), + tensor->grad), + inplace); + } + } break; + case GGML_OP_SGN: + { + if (src0->grad) { + // noop + } + } break; + case GGML_OP_NEG: + { + if (src0->grad) { + src0->grad = ggml_sub_impl(ctx, src0->grad, tensor->grad, inplace); + } + } break; + case GGML_OP_STEP: + { + if (src0->grad) { + // noop + } + } break; + case GGML_OP_RELU: + { + if (src0->grad) { + src0->grad = ggml_sub_impl(ctx, + src0->grad, + ggml_mul(ctx, + ggml_step(ctx, src0), + tensor->grad), + inplace); + } + } break; + case GGML_OP_GELU: + { + assert(false); // TODO: not implemented + } break; + case GGML_OP_NORM: + { + assert(false); // TODO: not implemented + } break; + case GGML_OP_MUL_MAT: + { + if (src0->grad) { + // TODO: this requires outer product - ggml_out_prod(ctx, src1, tensor->grad); + assert(false); + } + if (src1->grad) { + src1->grad = + ggml_add_impl(ctx, + src1->grad, + // TODO: fix transpose, the node will break the graph connections + ggml_mul_mat(ctx, ggml_transpose(ctx, src0), tensor->grad), + inplace); + } + } break; + case GGML_OP_SCALE: + { + GGML_ASSERT(false); // TODO: not implemented + } break; + case GGML_OP_CPY: + { + GGML_ASSERT(false); // TODO: not implemented + } break; + case GGML_OP_RESHAPE: + { + GGML_ASSERT(false); // TODO: not implemented + } break; + case GGML_OP_VIEW: + { + GGML_ASSERT(false); // not supported + } break; + case GGML_OP_PERMUTE: + { + GGML_ASSERT(false); // TODO: not implemented + } break; + case GGML_OP_TRANSPOSE: + { + GGML_ASSERT(false); // TODO: not implemented + } break; + case GGML_OP_GET_ROWS: + { + GGML_ASSERT(false); // TODO: not implemented + } break; + case GGML_OP_DIAG_MASK_INF: + { + GGML_ASSERT(false); // TODO: not implemented + } break; + case GGML_OP_SOFT_MAX: + { + GGML_ASSERT(false); // TODO: not implemented + } break; + case GGML_OP_ROPE: + { + GGML_ASSERT(false); // TODO: not implemented + } break; + case GGML_OP_CONV_1D_1S: + { + GGML_ASSERT(false); // TODO: not implemented + } 
break; + case GGML_OP_CONV_1D_2S: + { + GGML_ASSERT(false); // TODO: not implemented + } break; + case GGML_OP_FLASH_ATTN: + { + GGML_ASSERT(false); // not supported + } break; + case GGML_OP_FLASH_FF: + { + GGML_ASSERT(false); // not supported + } break; + case GGML_OP_NONE: + { + // nop + } break; + case GGML_OP_COUNT: + { + GGML_ASSERT(false); + } break; + } +} + +static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * node) { + if (node->grad == NULL) { + // this usually happens when we generate intermediate nodes from constants in the backward pass + // it can also happen during forward pass, if the user performs computations with constants + if (node->op != GGML_OP_NONE) { + //GGML_PRINT_DEBUG("%s: warning: node %p has no grad, but op %d\n", __func__, (void *) node, node->op); + } + } + + // check if already visited + for (int i = 0; i < cgraph->n_nodes; i++) { + if (cgraph->nodes[i] == node) { + return; + } + } + + for (int i = 0; i < cgraph->n_leafs; i++) { + if (cgraph->leafs[i] == node) { + return; + } + } + + if (node->src0) { + ggml_visit_parents(cgraph, node->src0); + } + + if (node->src1) { + ggml_visit_parents(cgraph, node->src1); + } + + for (int i = 0; i < GGML_MAX_OPT; ++i) { + if (node->opt[i]) { + ggml_visit_parents(cgraph, node->opt[i]); + } + } + + if (node->op == GGML_OP_NONE && node->grad == NULL) { + // reached a leaf node, not part of the gradient graph (e.g. a constant) + assert(cgraph->n_leafs < GGML_MAX_NODES); + + cgraph->leafs[cgraph->n_leafs] = node; + cgraph->n_leafs++; + } else { + assert(cgraph->n_nodes < GGML_MAX_NODES); + + cgraph->nodes[cgraph->n_nodes] = node; + cgraph->grads[cgraph->n_nodes] = node->grad; + cgraph->n_nodes++; + } +} + +static void ggml_build_forward_impl(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor, bool expand) { + if (!expand) { + cgraph->n_nodes = 0; + cgraph->n_leafs = 0; + } + + const int n0 = cgraph->n_nodes; + UNUSED(n0); + + ggml_visit_parents(cgraph, tensor); + + const int n_new = cgraph->n_nodes - n0; + GGML_PRINT_DEBUG("%s: visited %d new nodes\n", __func__, n_new); + + if (n_new > 0) { + // the last added node should always be starting point + assert(cgraph->nodes[cgraph->n_nodes - 1] == tensor); + } +} + +void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor) { + ggml_build_forward_impl(cgraph, tensor, true); +} + +struct ggml_cgraph ggml_build_forward(struct ggml_tensor * tensor) { + struct ggml_cgraph result = { + /*.n_nodes =*/ 0, + /*.n_leafs =*/ 0, + /*.n_threads =*/ 0, + /*.work_size =*/ 0, + /*.work =*/ NULL, + /*.nodes =*/ { NULL }, + /*.grads =*/ { NULL }, + /*.leafs =*/ { NULL }, + /*.perf_runs =*/ 0, + /*.perf_cycles =*/ 0, + /*.perf_time_us =*/ 0, + }; + + ggml_build_forward_impl(&result, tensor, false); + + return result; +} + +struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cgraph * gf, bool keep) { + struct ggml_cgraph result = *gf; + + assert(gf->n_nodes > 0); + + // if we are keeping the gradient graph, we have to detach the gradient nodes from the original graph + if (keep) { + for (int i = 0; i < gf->n_nodes; i++) { + struct ggml_tensor * node = gf->nodes[i]; + + if (node->grad) { + node->grad = ggml_dup_tensor(ctx, node); + gf->grads[i] = node->grad; + } + } + } + + for (int i = gf->n_nodes - 1; i >= 0; i--) { + struct ggml_tensor * node = gf->nodes[i]; + + // because we detached the grad nodes from the original graph, we can afford inplace operations + if (node->grad) { + ggml_compute_backward(ctx, 
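+            // ggml_visit_parents appended the nodes in topological order, so walking them in
+            // reverse guarantees that node->grad is fully accumulated before it is propagated
+            // to the node's sources here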
node, keep); + } + } + + for (int i = gf->n_nodes - 1; i >= 0; i--) { + struct ggml_tensor * node = gf->nodes[i]; + + if (node->is_param) { + GGML_PRINT_DEBUG("%s: found root node %p\n", __func__, (void *) node); + ggml_build_forward_impl(&result, node->grad, true); + } + } + + return result; +} + +// +// thread data +// +// synchronization is done via busy loops +// I tried using spin locks, but not sure how to use them correctly - the things I tried were slower than busy loops +// + +#ifdef __APPLE__ + +//#include +// +//typedef os_unfair_lock ggml_lock_t; +// +//#define ggml_lock_init(x) UNUSED(x) +//#define ggml_lock_destroy(x) UNUSED(x) +//#define ggml_lock_lock os_unfair_lock_lock +//#define ggml_lock_unlock os_unfair_lock_unlock +// +//#define GGML_LOCK_INITIALIZER OS_UNFAIR_LOCK_INIT + +typedef int ggml_lock_t; + +#define ggml_lock_init(x) UNUSED(x) +#define ggml_lock_destroy(x) UNUSED(x) +#define ggml_lock_lock(x) UNUSED(x) +#define ggml_lock_unlock(x) UNUSED(x) + +#define GGML_LOCK_INITIALIZER 0 + +typedef pthread_t ggml_thread_t; + +#define ggml_thread_create pthread_create +#define ggml_thread_join pthread_join + +#else + +//typedef pthread_spinlock_t ggml_lock_t; + +//#define ggml_lock_init(x) pthread_spin_init(x, PTHREAD_PROCESS_PRIVATE) +//#define ggml_lock_destroy pthread_spin_destroy +//#define ggml_lock_lock pthread_spin_lock +//#define ggml_lock_unlock pthread_spin_unlock + +typedef int ggml_lock_t; + +#define ggml_lock_init(x) UNUSED(x) +#define ggml_lock_destroy(x) UNUSED(x) +#define ggml_lock_lock(x) UNUSED(x) +#define ggml_lock_unlock(x) UNUSED(x) + +#define GGML_LOCK_INITIALIZER 0 + +typedef pthread_t ggml_thread_t; + +#define ggml_thread_create pthread_create +#define ggml_thread_join pthread_join + +#endif + +struct ggml_compute_state_shared { + ggml_lock_t spin; + + int n_threads; + + // synchronization primitives + atomic_int n_ready; + atomic_bool has_work; + atomic_bool stop; // stop all threads +}; + +struct ggml_compute_state { + ggml_thread_t thrd; + + struct ggml_compute_params params; + struct ggml_tensor * node; + + struct ggml_compute_state_shared * shared; +}; + +static thread_ret_t ggml_graph_compute_thread(void * data) { + struct ggml_compute_state * state = (struct ggml_compute_state *) data; + + const int n_threads = state->shared->n_threads; + + while (true) { + if (atomic_fetch_add(&state->shared->n_ready, 1) == n_threads - 1) { + atomic_store(&state->shared->has_work, false); + } else { + while (atomic_load(&state->shared->has_work)) { + if (atomic_load(&state->shared->stop)) { + return 0; + } + ggml_lock_lock (&state->shared->spin); + ggml_lock_unlock(&state->shared->spin); + } + } + + atomic_fetch_sub(&state->shared->n_ready, 1); + + // wait for work + while (!atomic_load(&state->shared->has_work)) { + if (atomic_load(&state->shared->stop)) { + return 0; + } + ggml_lock_lock (&state->shared->spin); + ggml_lock_unlock(&state->shared->spin); + } + + // check if we should stop + if (atomic_load(&state->shared->stop)) { + break; + } + + if (state->node) { + if (state->params.ith < state->params.nth) { + ggml_compute_forward(&state->params, state->node); + } + + state->node = NULL; + } else { + break; + } + } + + return 0; +} + +void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) { + if (cgraph->n_threads <= 0) { + cgraph->n_threads = 8; + } + + const int n_threads = cgraph->n_threads; + + struct ggml_compute_state_shared state_shared = { + /*.spin =*/ GGML_LOCK_INITIALIZER, + /*.n_threads =*/ n_threads, + /*.n_ready =*/ 
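+        // n_ready counts the threads that have arrived at the current sync point and has_work is
+        // the go/idle flag: the last thread to arrive clears has_work, and the main thread
+        // publishes the next node/params before raising it again (see ggml_graph_compute_thread
+        // above) - roughly a counter-based barrier built from the two atomics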
0, + /*.has_work =*/ false, + /*.stop =*/ false, + }; + struct ggml_compute_state * workers = n_threads > 1 ? alloca(sizeof(struct ggml_compute_state)*(n_threads - 1)) : NULL; + + // create thread pool + if (n_threads > 1) { + ggml_lock_init(&state_shared.spin); + + atomic_store(&state_shared.has_work, true); + + for (int j = 0; j < n_threads - 1; j++) { + workers[j] = (struct ggml_compute_state) { + .thrd = 0, + .params = { + .type = GGML_TASK_COMPUTE, + .ith = j + 1, + .nth = n_threads, + .wsize = cgraph->work ? ggml_nbytes(cgraph->work) : 0, + .wdata = cgraph->work ? cgraph->work->data : NULL, + }, + .node = NULL, + .shared = &state_shared, + }; + + int rc = ggml_thread_create(&workers[j].thrd, NULL, ggml_graph_compute_thread, &workers[j]); + assert(rc == 0); + UNUSED(rc); + } + } + + // initialize tasks + work buffer + { + size_t work_size = 0; + + // thread scheduling for the different operations + for (int i = 0; i < cgraph->n_nodes; i++) { + struct ggml_tensor * node = cgraph->nodes[i]; + + switch (node->op) { + case GGML_OP_DUP: + { + node->n_tasks = 1; + } break; + case GGML_OP_ADD: + { + node->n_tasks = n_threads; + } break; + case GGML_OP_SUB: + case GGML_OP_MUL: + case GGML_OP_DIV: + case GGML_OP_SQR: + case GGML_OP_SQRT: + case GGML_OP_SUM: + case GGML_OP_MEAN: + case GGML_OP_REPEAT: + case GGML_OP_ABS: + case GGML_OP_SGN: + case GGML_OP_NEG: + case GGML_OP_STEP: + case GGML_OP_RELU: + { + node->n_tasks = 1; + } break; + case GGML_OP_GELU: + { + node->n_tasks = n_threads; + } break; + case GGML_OP_NORM: + { + node->n_tasks = n_threads; + } break; + case GGML_OP_MUL_MAT: + { + node->n_tasks = n_threads; + + // TODO: use different scheduling for different matrix sizes + //const int nr0 = ggml_nrows(node->src0); + //const int nr1 = ggml_nrows(node->src1); + + //node->n_tasks = MIN(n_threads, MAX(1, nr0/128)); + //printf("nr0 = %8d, nr1 = %8d, nr0*nr1 = %8d, n_tasks = %d\n", nr0, nr1, nr0*nr1, node->n_tasks); + + size_t cur = 0; + + // TODO: better way to determine if the matrix is transposed + if (node->src0->nb[1] < node->src0->nb[0]) { + cur = ggml_nbytes(node)*node->n_tasks; // TODO: this can become (n_tasks-1) + } else { + if (node->src0->type == GGML_TYPE_F16 && + node->src1->type == GGML_TYPE_F32) { +#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) + if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) { + node->n_tasks = 1; // TODO: this actually is doing nothing + // the threads are still spinning + cur = sizeof(float)*(node->src0->ne[0]*node->src0->ne[1]); + //printf("src0: ne0 = %d, ne1 = %d, ne = %d\n", node->src0->ne[0], node->src0->ne[1], node->src0->ne[0]*node->src0->ne[1]); + //printf("src1: ne0 = %d, ne1 = %d, ne = %d\n", node->src1->ne[0], node->src1->ne[1], node->src1->ne[0]*node->src1->ne[1]); + //printf("cur = %zu\n", cur); + } else { + cur = sizeof(ggml_fp16_t)*ggml_nelements(node->src1); + } +#else + cur = sizeof(ggml_fp16_t)*ggml_nelements(node->src1); +#endif + } else if (node->src0->type == GGML_TYPE_F32 && + node->src1->type == GGML_TYPE_F32) { + cur = 0; + } else { + GGML_ASSERT(false); + } + } + + work_size = MAX(work_size, cur); + } break; + case GGML_OP_SCALE: + { + node->n_tasks = n_threads; + } break; + case GGML_OP_CPY: + case GGML_OP_RESHAPE: + case GGML_OP_VIEW: + case GGML_OP_PERMUTE: + case GGML_OP_TRANSPOSE: + case GGML_OP_GET_ROWS: + case GGML_OP_DIAG_MASK_INF: + { + node->n_tasks = 1; + } break; + case GGML_OP_SOFT_MAX: + { + node->n_tasks = n_threads; + } break; + case GGML_OP_ROPE: + { + node->n_tasks = 
1; + } break; + case GGML_OP_CONV_1D_1S: + case GGML_OP_CONV_1D_2S: + { + node->n_tasks = n_threads; + + GGML_ASSERT(node->src0->ne[3] == 1); + GGML_ASSERT(node->src1->ne[2] == 1); + GGML_ASSERT(node->src1->ne[3] == 1); + + size_t cur = 0; + const int nk = node->src0->ne[0]; + + if (node->src0->type == GGML_TYPE_F16 && + node->src1->type == GGML_TYPE_F32) { + cur = sizeof(ggml_fp16_t)*( + nk*ggml_up32(node->src0->ne[1])*node->src0->ne[2] + + ( 2*(nk/2) + node->src1->ne[0])*node->src1->ne[1] + ); + } else if (node->src0->type == GGML_TYPE_F32 && + node->src1->type == GGML_TYPE_F32) { + cur = sizeof(float)*( + nk*ggml_up32(node->src0->ne[1])*node->src0->ne[2] + + ( 2*(nk/2) + node->src1->ne[0])*node->src1->ne[1] + ); + } else { + GGML_ASSERT(false); + } + + work_size = MAX(work_size, cur); + } break; + case GGML_OP_FLASH_ATTN: + { + node->n_tasks = n_threads; + + size_t cur = 0; + + const int ne11 = ggml_up(node->src1->ne[1], GGML_SOFT_MAX_UNROLL); + + if (node->src1->type == GGML_TYPE_F32) { + cur = sizeof(float)*ne11*node->n_tasks; // TODO: this can become (n_tasks-1) + cur += sizeof(float)*ne11*node->n_tasks; // this is overestimated by x2 + } + + if (node->src1->type == GGML_TYPE_F16) { + cur = sizeof(float)*ne11*node->n_tasks; // TODO: this can become (n_tasks-1) + cur += sizeof(float)*ne11*node->n_tasks; // this is overestimated by x2 + } + + work_size = MAX(work_size, cur); + } break; + case GGML_OP_FLASH_FF: + { + node->n_tasks = n_threads; + + size_t cur = 0; + + if (node->src1->type == GGML_TYPE_F32) { + cur = sizeof(float)*node->src1->ne[1]*node->n_tasks; // TODO: this can become (n_tasks-1) + cur += sizeof(float)*node->src1->ne[1]*node->n_tasks; // this is overestimated by x2 + } + + if (node->src1->type == GGML_TYPE_F16) { + cur = sizeof(float)*node->src1->ne[1]*node->n_tasks; // TODO: this can become (n_tasks-1) + cur += sizeof(float)*node->src1->ne[1]*node->n_tasks; // this is overestimated by x2 + } + + work_size = MAX(work_size, cur); + } break; + case GGML_OP_NONE: + { + node->n_tasks = 1; + } break; + case GGML_OP_COUNT: + { + assert(false); + } break; + } + } + + if (cgraph->work != NULL && work_size > cgraph->work_size) { + assert(false); // TODO: better handling + } + + if (work_size > 0 && cgraph->work == NULL) { + cgraph->work_size = work_size + CACHE_LINE_SIZE*(n_threads - 1); + + GGML_PRINT_DEBUG("%s: allocating work buffer for graph (%zu bytes)\n", __func__, cgraph->work_size); + cgraph->work = ggml_new_tensor_1d(ctx, GGML_TYPE_I8, cgraph->work_size); + } + } + + const int64_t perf_start_cycles = ggml_perf_cycles(); + const int64_t perf_start_time_us = ggml_perf_time_us(); + + for (int i = 0; i < cgraph->n_nodes; i++) { + GGML_PRINT_DEBUG_5("%s: %d/%d\n", __func__, i, cgraph->n_nodes); + + struct ggml_tensor * node = cgraph->nodes[i]; + + // TODO: this could be used to avoid unnecessary computations, but it needs to be improved + //if (node->grad == NULL && node->perf_runs > 0) { + // continue; + //} + + const int64_t perf_node_start_cycles = ggml_perf_cycles(); + const int64_t perf_node_start_time_us = ggml_perf_time_us(); + + // INIT + struct ggml_compute_params params = { + /*.type =*/ GGML_TASK_INIT, + /*.ith =*/ 0, + /*.nth =*/ node->n_tasks, + /*.wsize =*/ cgraph->work ? ggml_nbytes(cgraph->work) : 0, + /*.wdata =*/ cgraph->work ? 
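+        // each node is evaluated in up to three phases: INIT (main thread only, ith = 0), then
+        // COMPUTE and FINALIZE, which are fanned out to the worker pool whenever node->n_tasks > 1;
+        // all phases share the single work buffer sized above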
cgraph->work->data : NULL, + }; + + ggml_compute_forward(¶ms, node); + + // COMPUTE + if (node->n_tasks > 1) { + if (atomic_fetch_add(&state_shared.n_ready, 1) == n_threads - 1) { + atomic_store(&state_shared.has_work, false); + } + + while (atomic_load(&state_shared.has_work)) { + ggml_lock_lock (&state_shared.spin); + ggml_lock_unlock(&state_shared.spin); + } + + // launch thread pool + for (int j = 0; j < n_threads - 1; j++) { + workers[j].params = (struct ggml_compute_params) { + .type = GGML_TASK_COMPUTE, + .ith = j + 1, + .nth = node->n_tasks, + .wsize = cgraph->work ? ggml_nbytes(cgraph->work) : 0, + .wdata = cgraph->work ? cgraph->work->data : NULL, + }; + workers[j].node = node; + } + + atomic_fetch_sub(&state_shared.n_ready, 1); + + while (atomic_load(&state_shared.n_ready) > 0) { + ggml_lock_lock (&state_shared.spin); + ggml_lock_unlock(&state_shared.spin); + } + + atomic_store(&state_shared.has_work, true); + } + + params.type = GGML_TASK_COMPUTE; + ggml_compute_forward(¶ms, node); + + // wait for thread pool + if (node->n_tasks > 1) { + if (atomic_fetch_add(&state_shared.n_ready, 1) == n_threads - 1) { + atomic_store(&state_shared.has_work, false); + } + + while (atomic_load(&state_shared.has_work)) { + ggml_lock_lock (&state_shared.spin); + ggml_lock_unlock(&state_shared.spin); + } + + atomic_fetch_sub(&state_shared.n_ready, 1); + + while (atomic_load(&state_shared.n_ready) != 0) { + ggml_lock_lock (&state_shared.spin); + ggml_lock_unlock(&state_shared.spin); + } + } + + // FINALIZE + if (node->n_tasks > 1) { + if (atomic_fetch_add(&state_shared.n_ready, 1) == n_threads - 1) { + atomic_store(&state_shared.has_work, false); + } + + while (atomic_load(&state_shared.has_work)) { + ggml_lock_lock (&state_shared.spin); + ggml_lock_unlock(&state_shared.spin); + } + + // launch thread pool + for (int j = 0; j < n_threads - 1; j++) { + workers[j].params = (struct ggml_compute_params) { + .type = GGML_TASK_FINALIZE, + .ith = j + 1, + .nth = node->n_tasks, + .wsize = cgraph->work ? ggml_nbytes(cgraph->work) : 0, + .wdata = cgraph->work ? 
cgraph->work->data : NULL, + }; + workers[j].node = node; + } + + atomic_fetch_sub(&state_shared.n_ready, 1); + + while (atomic_load(&state_shared.n_ready) > 0) { + ggml_lock_lock (&state_shared.spin); + ggml_lock_unlock(&state_shared.spin); + } + + atomic_store(&state_shared.has_work, true); + } + + params.type = GGML_TASK_FINALIZE; + ggml_compute_forward(¶ms, node); + + // wait for thread pool + if (node->n_tasks > 1) { + if (atomic_fetch_add(&state_shared.n_ready, 1) == n_threads - 1) { + atomic_store(&state_shared.has_work, false); + } + + while (atomic_load(&state_shared.has_work)) { + ggml_lock_lock (&state_shared.spin); + ggml_lock_unlock(&state_shared.spin); + } + + atomic_fetch_sub(&state_shared.n_ready, 1); + + while (atomic_load(&state_shared.n_ready) != 0) { + ggml_lock_lock (&state_shared.spin); + ggml_lock_unlock(&state_shared.spin); + } + } + + // performance stats (node) + { + int64_t perf_cycles_cur = ggml_perf_cycles() - perf_node_start_cycles; + int64_t perf_time_us_cur = ggml_perf_time_us() - perf_node_start_time_us; + + node->perf_runs++; + node->perf_cycles += perf_cycles_cur; + node->perf_time_us += perf_time_us_cur; + } + } + + // join thread pool + if (n_threads > 1) { + atomic_store(&state_shared.stop, true); + atomic_store(&state_shared.has_work, true); + + for (int j = 0; j < n_threads - 1; j++) { + int rc = ggml_thread_join(workers[j].thrd, NULL); + assert(rc == 0); + UNUSED(rc); + } + + ggml_lock_destroy(&state_shared.spin); + } + + // performance stats (graph) + { + int64_t perf_cycles_cur = ggml_perf_cycles() - perf_start_cycles; + int64_t perf_time_us_cur = ggml_perf_time_us() - perf_start_time_us; + + cgraph->perf_runs++; + cgraph->perf_cycles += perf_cycles_cur; + cgraph->perf_time_us += perf_time_us_cur; + + GGML_PRINT_DEBUG("%s: perf (%d) - cpu = %.3f / %.3f ms, wall = %.3f / %.3f ms\n", + __func__, cgraph->perf_runs, + (double) perf_cycles_cur / (double) ggml_cycles_per_ms(), + (double) cgraph->perf_cycles / (double) ggml_cycles_per_ms() / (double) cgraph->perf_runs, + (double) perf_time_us_cur / 1000.0, + (double) cgraph->perf_time_us / 1000.0 / cgraph->perf_runs); + } +} + +void ggml_graph_reset(struct ggml_cgraph * cgraph) { + for (int i = 0; i < cgraph->n_nodes; i++) { + struct ggml_tensor * grad = cgraph->grads[i]; + + if (grad) { + ggml_set_zero(grad); + } + } +} + +void ggml_graph_print(const struct ggml_cgraph * cgraph) { + int64_t perf_total_per_op_us[GGML_OP_COUNT] = {0}; + + GGML_PRINT("=== GRAPH ===\n"); + + GGML_PRINT_DEBUG("n_threads = %d\n", cgraph->n_threads); + GGML_PRINT_DEBUG("total work size = %zu bytes\n",cgraph->work_size); + + GGML_PRINT("n_nodes = %d\n", cgraph->n_nodes); + for (int i = 0; i < cgraph->n_nodes; i++) { + struct ggml_tensor * node = cgraph->nodes[i]; + + perf_total_per_op_us[node->op] += node->perf_time_us; + + GGML_PRINT(" - %3d: [ %6d, %6d, %6d] %16s %s (%3d) cpu = %7.3f / %7.3f ms, wall = %7.3f / %7.3f ms\n", + i, + node->ne[0], node->ne[1], node->ne[2], + GGML_OP_LABEL[node->op], node->is_param ? "x" : node->grad ? 
"g" : " ", node->perf_runs, + (double) node->perf_cycles / (double) ggml_cycles_per_ms(), + (double) node->perf_cycles / (double) ggml_cycles_per_ms() / (double) node->perf_runs, + (double) node->perf_time_us / 1000.0, + (double) node->perf_time_us / 1000.0 / node->perf_runs); + } + + GGML_PRINT("n_leafs = %d\n", cgraph->n_leafs); + for (int i = 0; i < cgraph->n_leafs; i++) { + struct ggml_tensor * node = cgraph->leafs[i]; + + GGML_PRINT(" - %3d: [ %6d, %6d] %8s\n", + i, + node->ne[0], node->ne[1], + GGML_OP_LABEL[node->op]); + } + + for (int i = 0; i < GGML_OP_COUNT; i++) { + GGML_PRINT("perf_total_per_op_us[%16s] = %7.3f ms\n", GGML_OP_LABEL[i], (double) perf_total_per_op_us[i] / 1000.0); + } + + GGML_PRINT("========================================\n"); +} + +// check if node is part of the graph +static bool ggml_graph_find(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node) { + if (cgraph == NULL) { + return true; + } + + for (int i = 0; i < cgraph->n_nodes; i++) { + if (cgraph->nodes[i] == node) { + return true; + } + } + + return false; +} + +static struct ggml_tensor * ggml_graph_get_parent(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node) { + for (int i = 0; i < cgraph->n_nodes; i++) { + struct ggml_tensor * parent = cgraph->nodes[i]; + + if (parent->grad == node) { + return parent; + } + } + + return NULL; +} + +void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * gf, const char * filename) { + char color[16]; + + FILE * fp = fopen(filename, "w"); + assert(fp); + + fprintf(fp, "digraph G {\n"); + fprintf(fp, " newrank = true;\n"); + fprintf(fp, " rankdir = LR;\n"); + + for (int i = 0; i < gb->n_nodes; i++) { + struct ggml_tensor * node = gb->nodes[i]; + + if (ggml_graph_get_parent(gb, node) != NULL) { + continue; + } + + if (node->is_param) { + snprintf(color, sizeof(color), "yellow"); + } else if (node->grad) { + if (ggml_graph_find(gf, node)) { + snprintf(color, sizeof(color), "green"); + } else { + snprintf(color, sizeof(color), "lightblue"); + } + } else { + snprintf(color, sizeof(color), "white"); + } + + fprintf(fp, " \"%p\" [ \ +style = filled; fillcolor = %s; shape = record; \ +label=\"%d [%d, %d] | %s", + (void *) node, color, + i, node->ne[0], node->ne[1], + GGML_OP_SYMBOL[node->op]); + + if (node->grad) { + fprintf(fp, " | %s\"; ]\n", GGML_OP_SYMBOL[node->grad->op]); + } else { + fprintf(fp, "\"; ]\n"); + } + } + + for (int i = 0; i < gb->n_leafs; i++) { + struct ggml_tensor * node = gb->leafs[i]; + + snprintf(color, sizeof(color), "pink"); + + if (ggml_nelements(node) == 1) { + fprintf(fp, " \"%p\" [ \ +style = filled; fillcolor = %s; shape = record; \ +label=\"%.1e\"; ]\n", + (void *) node, color, ggml_get_f32_1d(node, 0)); + } else { + fprintf(fp, " \"%p\" [ \ +style = filled; fillcolor = %s; shape = record; \ +label=\"CONST %d [%d, %d]\"; ]\n", + (void *) node, color, + i, node->ne[0], node->ne[1]); + } + } + + for (int i = 0; i < gb->n_nodes; i++) { + struct ggml_tensor * node = gb->nodes[i]; + + struct ggml_tensor * parent = ggml_graph_get_parent(gb, node); + + if (node->src0) { + struct ggml_tensor * parent0 = ggml_graph_get_parent(gb, node->src0); + + fprintf(fp, " \"%p\":%s -> \"%p\":%s [ arrowhead = %s; style = %s; label = \"x\"; ]\n", + parent0 ? (void *) parent0 : (void *) node->src0, + parent0 ? "g" : "x", + parent ? (void *) parent : (void *) node, + parent ? "g" : "x", + parent ? "empty" : "vee", + parent ? 
"dashed" : "solid"); + } + + if (node->src1) { + struct ggml_tensor * parent1 = ggml_graph_get_parent(gb, node->src1); + + fprintf(fp, " \"%p\":%s -> \"%p\":%s [ arrowhead = %s; style = %s; label = \"y\"; ]\n", + parent1 ? (void *) parent1 : (void *) node->src1, + parent1 ? "g" : "x", + parent ? (void *) parent : (void *) node, + parent ? "g" : "x", + parent ? "empty" : "vee", + parent ? "dashed" : "solid"); + } + } + + for (int i = 0; i < gb->n_leafs; i++) { + struct ggml_tensor * node = gb->leafs[i]; + + if (node->src0) { + fprintf(fp, " \"%p\":%s -> \"%p\":%s [ label = \"x\"; ]\n", + (void *) node->src0, "x", + (void *) node, "x"); + } + + if (node->src1) { + fprintf(fp, " \"%p\":%s -> \"%p\":%s [ label = \"y\"; ]\n", + (void *) node->src1, "x", + (void *) node, "x"); + } + } + + fprintf(fp, "}\n"); + + fclose(fp); + + GGML_PRINT("%s: dot -Tpng %s -o %s.png && open %s.png\n", __func__, filename, filename, filename); +} + +//////////////////////////////////////////////////////////////////////////////// + +static void ggml_opt_set_params(int np, struct ggml_tensor * const ps[], const float * x) { + int i = 0; + for (int p = 0; p < np; ++p) { + const int ne = ggml_nelements(ps[p]) ; + // TODO: add function to set tensor from array + for (int j = 0; j < ne; ++j) { + ggml_set_f32_1d(ps[p], j, x[i++]); + } + } +} + +static void ggml_opt_get_params(int np, struct ggml_tensor * const ps[], float * x) { + int i = 0; + for (int p = 0; p < np; ++p) { + const int ne = ggml_nelements(ps[p]) ; + // TODO: add function to get all elements at once + for (int j = 0; j < ne; ++j) { + x[i++] = ggml_get_f32_1d(ps[p], j); + } + } +} + +static void ggml_opt_get_grad(int np, struct ggml_tensor * const ps[], float * g) { + int i = 0; + for (int p = 0; p < np; ++p) { + const int ne = ggml_nelements(ps[p]) ; + // TODO: add function to get all elements at once + for (int j = 0; j < ne; ++j) { + g[i++] = ggml_get_f32_1d(ps[p]->grad, j); + } + } +} + +// +// ADAM +// +// ref: https://arxiv.org/pdf/1412.6980.pdf +// + +static enum ggml_opt_result ggml_opt_adam( + struct ggml_context * ctx, + struct ggml_opt_params params, + struct ggml_tensor * f, + struct ggml_cgraph * gf, + struct ggml_cgraph * gb) { + assert(ggml_is_scalar(f)); + + gf->n_threads = params.n_threads; + gb->n_threads = params.n_threads; + + // these will store the parameters we want to optimize + struct ggml_tensor * ps[GGML_MAX_PARAMS]; + + int np = 0; + int nx = 0; + for (int i = 0; i < gf->n_nodes; ++i) { + if (gf->nodes[i]->is_param) { + GGML_PRINT_DEBUG("found param %d: grad->op = %d\n", np, gf->nodes[i]->grad->op); + + assert(np < GGML_MAX_PARAMS); + + ps[np++] = gf->nodes[i]; + nx += ggml_nelements(gf->nodes[i]); + } + } + + // constants + const float alpha = params.adam.alpha; + const float beta1 = params.adam.beta1; + const float beta2 = params.adam.beta2; + const float eps = params.adam.eps; + + float * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // view of the parameters + float * g1 = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // gradient + float * g2 = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // gradient squared + float * m = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // first moment + float * v = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // second moment + float * mh = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // first moment hat + float * vh = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // second moment hat + + float * pf = params.past > 0 ? 
ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past)->data : NULL; // past function values + + // initialize + ggml_vec_set_f32(nx, m, 0.0f); + ggml_vec_set_f32(nx, v, 0.0f); + + // update view + ggml_opt_get_params(np, ps, x); + + // compute the function value + ggml_graph_reset (gf); + ggml_set_f32 (f->grad, 1.0f); + ggml_graph_compute(ctx, gb); + + float fx_prev = ggml_get_f32_1d(f, 0); + if (pf) { + pf[0] = fx_prev; + } + + int n_no_improvement = 0; + float fx_best = fx_prev; + + // run the optimizer + for (int t = 0; t < params.adam.n_iter; ++t) { + GGML_PRINT_DEBUG ("=== iter %d ===\n", t); + + GGML_PRINT_DEBUG ("f = %10.6f\n", ggml_get_f32_1d(f, 0)); + GGML_PRINT_DEBUG_5("df/dx0 = %10.6f\n", ggml_get_f32_1d(ps[0]->grad, 0)); + GGML_PRINT_DEBUG_5("df/dx1 = %10.6f\n", ggml_get_f32_1d(ps[1]->grad, 0)); + + for (int i = 0; i < np; ++i) { + GGML_PRINT_DEBUG("param %d: %10.6f, g = %10.6f\n", i, + ggml_get_f32_1d(ps[i], 0), ggml_get_f32_1d(ps[i]->grad, 0)); + } + + const int64_t t_start_wall = ggml_time_us(); + const int64_t t_start_cpu = ggml_cycles(); + UNUSED(t_start_wall); + UNUSED(t_start_cpu); + + { + // update the gradient + ggml_opt_get_grad(np, ps, g1); + + // m_t = beta1*m_t-1 + (1 - beta1)*g_t + ggml_vec_scale_f32(nx, m, beta1); + ggml_vec_mad_f32 (nx, m, g1, 1.0f - beta1); + + // g2 = g1^2 + ggml_vec_sqr_f32 (nx, g2, g1); + + // v_t = beta2*v_t-1 + (1 - beta2)*g_t^2 + ggml_vec_scale_f32(nx, v, beta2); + ggml_vec_mad_f32 (nx, v, g2, 1.0f - beta2); + + // m^hat = m_t / (1 - beta1^t) + // v^hat = v_t / (1 - beta2^t) + // x_t = x_t-1 - alpha*m^hat/(sqrt(v^hat) + eps) + ggml_vec_cpy_f32 (nx, mh, m); + ggml_vec_cpy_f32 (nx, vh, v); + + ggml_vec_scale_f32(nx, mh, alpha/(1.0f - powf(beta1, t + 1))); + ggml_vec_scale_f32(nx, vh, 1.0f/(1.0f - powf(beta2, t + 1))); + + ggml_vec_sqrt_f32 (nx, vh, vh); + ggml_vec_acc1_f32 (nx, vh, eps); + + ggml_vec_div_f32 (nx, mh, mh, vh); + ggml_vec_sub_f32 (nx, x, x, mh); + + // update the parameters + ggml_opt_set_params(np, ps, x); + } + + ggml_graph_reset (gf); + ggml_set_f32 (f->grad, 1.0f); + ggml_graph_compute(ctx, gb); + + const float fx = ggml_get_f32_1d(f, 0); + + // check convergence + if (fabsf(fx - fx_prev)/fx < params.adam.eps_f) { + GGML_PRINT_DEBUG("converged\n"); + + return GGML_OPT_OK; + } + + // delta-based convergence test + if (pf != NULL) { + // need at least params.past iterations to start checking for convergence + if (params.past <= t) { + const float rate = (pf[t%params.past] - fx)/fx; + + if (fabs(rate) < params.delta) { + return GGML_OPT_OK; + } + } + + pf[t%params.past] = fx; + } + + // check for improvement + if (params.max_no_improvement > 0) { + if (fx_best > fx) { + fx_best = fx; + n_no_improvement = 0; + } else { + ++n_no_improvement; + + if (n_no_improvement >= params.max_no_improvement) { + return GGML_OPT_OK; + } + } + } + + fx_prev = fx; + + { + const int64_t t_end_cpu = ggml_cycles(); + GGML_PRINT_DEBUG("time iter: %5.3f s\n", ((float)(t_end_cpu - t_start_cpu))/CLOCKS_PER_SEC); + UNUSED(t_end_cpu); + + const int64_t t_end_wall = ggml_time_us(); + GGML_PRINT_DEBUG("wall time iter: %5.3f s\n", (t_end_wall - t_start_wall)/1e6); + UNUSED(t_end_wall); + } + } + + return GGML_OPT_DID_NOT_CONVERGE; +} + +// +// L-BFGS +// +// the L-BFGS implementation below is based on the following implementation: +// +// https://github.com/chokkan/liblbfgs +// + +struct ggml_lbfgs_iteration_data { + float alpha; + float ys; + float * s; + float * y; +}; + +static enum ggml_opt_result linesearch_backtracking( + struct 
ggml_context * ctx, + const struct ggml_opt_params * params, + int nx, + float * x, + float * fx, + float * g, + float * d, + float * step, + const float * xp, + struct ggml_tensor * f, + struct ggml_cgraph * gf, + struct ggml_cgraph * gb, + const int np, + struct ggml_tensor * ps[]) { + int count = 0; + + float width = 0.0f; + float dg = 0.0f; + float finit = 0.0f; + float dginit = 0.0f; + float dgtest = 0.0f; + + const float dec = 0.5f; + const float inc = 2.1f; + + if (*step <= 0.) { + return GGML_LINESEARCH_INVALID_PARAMETERS; + } + + // compute the initial gradient in the search direction + ggml_vec_dot_f32(nx, &dginit, g, d); + + // make sure that d points to a descent direction + if (0 < dginit) { + return GGML_LINESEARCH_FAIL; + } + + // initialize local variables + finit = *fx; + dgtest = params->lbfgs.ftol*dginit; + + while (true) { + ggml_vec_cpy_f32(nx, x, xp); + ggml_vec_mad_f32(nx, x, d, *step); + + // evaluate the function and gradient values + { + ggml_opt_set_params(np, ps, x); + + ggml_graph_reset (gf); + ggml_set_f32 (f->grad, 1.0f); + ggml_graph_compute(ctx, gb); + + ggml_opt_get_grad(np, ps, g); + + *fx = ggml_get_f32_1d(f, 0); + } + + ++count; + + if (*fx > finit + (*step)*dgtest) { + width = dec; + } else { + // Armijo condition is satisfied + if (params->lbfgs.linesearch == GGML_LINESEARCH_BACKTRACKING_ARMIJO) { + return count; + } + + ggml_vec_dot_f32(nx, &dg, g, d); + + // check the Wolfe condition + if (dg < params->lbfgs.wolfe * dginit) { + width = inc; + } else { + if(params->lbfgs.linesearch == GGML_LINESEARCH_BACKTRACKING_WOLFE) { + // regular Wolfe conditions + return count; + } + + if(dg > -params->lbfgs.wolfe*dginit) { + width = dec; + } else { + // strong Wolfe condition (GGML_LINESEARCH_BACKTRACKING_STRONG_WOLFE) + return count; + } + return count; + } + } + + if (*step < params->lbfgs.min_step) { + return GGML_LINESEARCH_MINIMUM_STEP; + } + if (*step > params->lbfgs.max_step) { + return GGML_LINESEARCH_MAXIMUM_STEP; + } + if (params->lbfgs.max_linesearch <= count) { + return GGML_LINESEARCH_MAXIMUM_ITERATIONS; + } + + (*step) *= width; + } + + return GGML_LINESEARCH_FAIL; +} + +static enum ggml_opt_result ggml_opt_lbfgs( + struct ggml_context * ctx, + struct ggml_opt_params params, + struct ggml_tensor * f, + struct ggml_cgraph * gf, + struct ggml_cgraph * gb) { + if (params.lbfgs.linesearch == GGML_LINESEARCH_BACKTRACKING_WOLFE || + params.lbfgs.linesearch == GGML_LINESEARCH_BACKTRACKING_STRONG_WOLFE) { + if (params.lbfgs.wolfe <= params.lbfgs.ftol || 1. 
<= params.lbfgs.wolfe) { + return GGML_OPT_INVALID_WOLFE; + } + } + + gf->n_threads = params.n_threads; + gb->n_threads = params.n_threads; + + const int m = params.lbfgs.m; + + // these will store the parameters we want to optimize + struct ggml_tensor * ps[GGML_MAX_PARAMS]; + + int np = 0; + int nx = 0; + for (int i = 0; i < gf->n_nodes; ++i) { + if (gf->nodes[i]->is_param) { + GGML_PRINT_DEBUG("found param %d: grad->op = %d\n", np, gf->nodes[i]->grad->op); + + assert(np < GGML_MAX_PARAMS); + + ps[np++] = gf->nodes[i]; + nx += ggml_nelements(gf->nodes[i]); + } + } + + float * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // current parameters + float * xp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // previous parameters + float * g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // current gradient + float * gp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // previous gradient + float * d = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // search direction + + float * pf = params.past > 0 ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past)->data : NULL; // past function values + + float fx = 0.0f; // cost function value + float xnorm = 0.0f; // ||x|| + float gnorm = 0.0f; // ||g|| + float step = 0.0f; + + // initialize x from the graph nodes + ggml_opt_get_params(np, ps, x); + + // the L-BFGS memory + struct ggml_lbfgs_iteration_data * lm = alloca(sizeof(struct ggml_lbfgs_iteration_data)*m); + + for (int i = 0; i < m; ++i) { + lm[i].alpha = 0.0f; + lm[i].ys = 0.0f; + lm[i].s = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; + lm[i].y = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; + } + + // evaluate the function value and its gradient + { + ggml_opt_set_params(np, ps, x); + + ggml_graph_reset (gf); + ggml_set_f32 (f->grad, 1.0f); + ggml_graph_compute(ctx, gb); + + ggml_opt_get_grad(np, ps, g); + + fx = ggml_get_f32_1d(f, 0); + } + + if (pf) { + pf[0] = fx; + } + + float fx_best = fx; + + // search direction = -gradient + ggml_vec_neg_f32(nx, d, g); + + // ||x||, ||g|| + ggml_vec_norm_f32(nx, &xnorm, x); + ggml_vec_norm_f32(nx, &gnorm, g); + + if (xnorm < 1.0f) { + xnorm = 1.0f; + } + + // already optimized + if (gnorm/xnorm <= params.lbfgs.eps) { + return GGML_OPT_OK; + } + + // initial step + ggml_vec_norm_inv_f32(nx, &step, d); + + int j = 0; + int k = 1; + int ls = 0; + int end = 0; + int bound = 0; + int n_no_improvement = 0; + + float ys = 0.0f; + float yy = 0.0f; + float beta = 0.0f; + + while (true) { + // store the current position and gradient vectors + ggml_vec_cpy_f32(nx, xp, x); + ggml_vec_cpy_f32(nx, gp, g); + + ls = linesearch_backtracking(ctx, ¶ms, nx, x, &fx, g, d, &step, xp, f, gf, gb, np, ps); + + if (ls < 0) { + // linesearch failed - go back to the previous point and return + ggml_vec_cpy_f32(nx, x, xp); + ggml_vec_cpy_f32(nx, g, gp); + + return ls; + } + + ggml_vec_norm_f32(nx, &xnorm, x); + ggml_vec_norm_f32(nx, &gnorm, g); + + GGML_PRINT_DEBUG("f = %10.6f\n", ggml_get_f32_1d(f, 0)); + + if (xnorm < 1.0) { + xnorm = 1.0; + } + if (gnorm/xnorm <= params.lbfgs.eps) { + // converged + return GGML_OPT_OK; + } + + // delta-based convergence test + if (pf != NULL) { + // need at least params.past iterations to start checking for convergence + if (params.past <= k) { + const float rate = (pf[k%params.past] - fx)/fx; + + if (fabs(rate) < params.delta) { + return GGML_OPT_OK; + } + } + + pf[k%params.past] = fx; + } + + // check for improvement + if (params.max_no_improvement > 0) { + if (fx < fx_best) { + fx_best = fx; + 
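+                // a new best objective value resets the stall counter below; otherwise, after
+                // params.max_no_improvement consecutive non-improving iterations the search is
+                // treated as converged and returns GGML_OPT_OK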
n_no_improvement = 0; + } else { + n_no_improvement++; + + if (n_no_improvement >= params.max_no_improvement) { + return GGML_OPT_OK; + } + } + } + + if (params.lbfgs.n_iter != 0 && params.lbfgs.n_iter < k + 1) { + // reached the maximum number of iterations + return GGML_OPT_DID_NOT_CONVERGE; + } + + // update vectors s and y: + // s_{k+1} = x_{k+1} - x_{k} = \step * d_{k}. + // y_{k+1} = g_{k+1} - g_{k}. + // + ggml_vec_sub_f32(nx, lm[end].s, x, xp); + ggml_vec_sub_f32(nx, lm[end].y, g, gp); + + // compute scalars ys and yy: + // ys = y^t \cdot s -> 1 / \rho. + // yy = y^t \cdot y. + // + ggml_vec_dot_f32(nx, &ys, lm[end].y, lm[end].s); + ggml_vec_dot_f32(nx, &yy, lm[end].y, lm[end].y); + + lm[end].ys = ys; + + // find new search direction + // ref: https://en.wikipedia.org/wiki/Limited-memory_BFGS + + bound = (m <= k) ? m : k; + k++; + end = (end + 1)%m; + + // initialize search direction with -g + ggml_vec_neg_f32(nx, d, g); + + j = end; + for (int i = 0; i < bound; ++i) { + j = (j + m - 1) % m; + // \alpha_{j} = \rho_{j} s^{t}_{j} \cdot q_{k+1} + ggml_vec_dot_f32(nx, &lm[j].alpha, lm[j].s, d); + lm[j].alpha /= lm[j].ys; + // q_{i} = q_{i+1} - \alpha_{i} y_{i} + ggml_vec_mad_f32(nx, d, lm[j].y, -lm[j].alpha); + } + + ggml_vec_scale_f32(nx, d, ys/yy); + + for (int i = 0; i < bound; ++i) { + // \beta_{j} = \rho_{j} y^t_{j} \cdot \gamma_{i} + ggml_vec_dot_f32(nx, &beta, lm[j].y, d); + beta /= lm[j].ys; + // \gamma_{i+1} = \gamma_{i} + (\alpha_{j} - \beta_{j}) s_{j} + ggml_vec_mad_f32(nx, d, lm[j].s, lm[j].alpha - beta); + j = (j + 1)%m; + } + + step = 1.0; + } + + return GGML_OPT_DID_NOT_CONVERGE; +} + +struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type) { + struct ggml_opt_params result; + + switch (type) { + case GGML_OPT_ADAM: + { + result = (struct ggml_opt_params) { + .type = GGML_OPT_ADAM, + .n_threads = 1, + .past = 0, + .delta = 1e-5f, + + .max_no_improvement = 100, + + .print_forward_graph = true, + .print_backward_graph = true, + + .adam = { + .n_iter = 10000, + .alpha = 0.001f, + .beta1 = 0.9f, + .beta2 = 0.999f, + .eps = 1e-8f, + .eps_f = 1e-5f, + .eps_g = 1e-3f, + }, + }; + } break; + case GGML_OPT_LBFGS: + { + result = (struct ggml_opt_params) { + .type = GGML_OPT_LBFGS, + .n_threads = 1, + .past = 0, + .delta = 1e-5f, + + .max_no_improvement = 0, + + .print_forward_graph = true, + .print_backward_graph = true, + + .lbfgs = { + .m = 6, + .n_iter = 100, + .max_linesearch = 20, + + .eps = 1e-5f, + .ftol = 1e-4f, + .wolfe = 0.9f, + .min_step = 1e-20f, + .max_step = 1e+20f, + + .linesearch = GGML_LINESEARCH_DEFAULT, + }, + }; + } break; + } + + return result; +} + +enum ggml_opt_result ggml_opt( + struct ggml_context * ctx, + struct ggml_opt_params params, + struct ggml_tensor * f) { + bool free_ctx = false; + if (ctx == NULL) { + struct ggml_init_params params_ctx = { + .mem_size = 16*1024*1024, + .mem_buffer = NULL, + }; + + ctx = ggml_init(params_ctx); + if (ctx == NULL) { + return GGML_OPT_NO_CONTEXT; + } + + free_ctx = true; + } + + enum ggml_opt_result result = GGML_OPT_OK; + + // build forward + backward compute graphs + struct ggml_cgraph gf = ggml_build_forward (f); + struct ggml_cgraph gb = ggml_build_backward(ctx, &gf, false); + + switch (params.type) { + case GGML_OPT_ADAM: + { + result = ggml_opt_adam(ctx, params, f, &gf, &gb); + } break; + case GGML_OPT_LBFGS: + { + result = ggml_opt_lbfgs(ctx, params, f, &gf, &gb); + } break; + } + + if (params.print_forward_graph) { + ggml_graph_print (&gf); + ggml_graph_dump_dot(&gf, NULL, "opt-forward.dot"); 
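+        // the .dot files can be rendered with graphviz (the exact command is also printed by
+        // ggml_graph_dump_dot); in the backward dump below, passing &gf lets nodes that already
+        // exist in the forward graph be colored differently from the ones added only for the
+        // backward pass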
+ } + + if (params.print_backward_graph) { + ggml_graph_print (&gb); + ggml_graph_dump_dot(&gb, &gf, "opt-backward.dot"); + } + + if (free_ctx) { + ggml_free(ctx); + } + + return result; +} + +//////////////////////////////////////////////////////////////////////////////// + +int ggml_cpu_has_avx(void) { +#if defined(__AVX__) + return 1; +#else + return 0; +#endif +} + +int ggml_cpu_has_avx2(void) { +#if defined(__AVX2__) + return 1; +#else + return 0; +#endif +} + +int ggml_cpu_has_avx512(void) { +#if defined(__AVX512F__) + return 1; +#else + return 0; +#endif +} + +int ggml_cpu_has_fma(void) { +#if defined(__FMA__) + return 1; +#else + return 0; +#endif +} + +int ggml_cpu_has_neon(void) { +#if defined(__ARM_NEON) + return 1; +#else + return 0; +#endif +} + +int ggml_cpu_has_arm_fma(void) { +#if defined(__ARM_FEATURE_FMA) + return 1; +#else + return 0; +#endif +} + +int ggml_cpu_has_f16c(void) { +#if defined(__F16C__) + return 1; +#else + return 0; +#endif +} + +int ggml_cpu_has_fp16_va(void) { +#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) + return 1; +#else + return 0; +#endif +} + +int ggml_cpu_has_wasm_simd(void) { +#if defined(__wasm_simd128__) + return 1; +#else + return 0; +#endif +} + +int ggml_cpu_has_blas(void) { +#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) + return 1; +#else + return 0; +#endif +} + +int ggml_cpu_has_sse3(void) { +#if defined(__SSE3__) + return 1; +#else + return 0; +#endif +} + +int ggml_cpu_has_vsx(void) { +#if defined(__POWER9_VECTOR__) + return 1; +#else + return 0; +#endif +} + +//////////////////////////////////////////////////////////////////////////////// diff --git a/bindings/ruby/ext/ggml.h b/bindings/ruby/ext/ggml.h new file mode 100644 index 00000000000..18f317bec04 --- /dev/null +++ b/bindings/ruby/ext/ggml.h @@ -0,0 +1,748 @@ +#pragma once + +// +// GGML Tensor Library +// +// This documentation is still a work in progress. +// If you wish some specific topics to be covered, feel free to drop a comment: +// +// https://github.com/ggerganov/whisper.cpp/issues/40 +// +// ## Overview +// +// This library implements: +// +// - a set of tensor operations +// - automatic differentiation +// - basic optimization algorithms +// +// The aim of this library is to provide a minimalistic approach for various machine learning tasks. This includes, +// but is not limited to, the following: +// +// - linear regression +// - support vector machines +// - neural networks +// +// The library allows the user to define a certain function using the available tensor operations. This function +// definition is represented internally via a computation graph. Each tensor operation in the function definition +// corresponds to a node in the graph. Having the computation graph defined, the user can choose to compute the +// function's value and/or its gradient with respect to the input variables. Optionally, the function can be optimized +// using one of the available optimization algorithms. 
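+// As a brief sketch of that optional last step (added here purely for illustration; it just
+// combines the declarations found further down in this header): once a scalar tensor f has been
+// defined and the tensors to be tuned have been marked with ggml_set_param(), the built-in
+// optimizers can minimize f, e.g.:
+//
+//   {
+//       ...
+//
+//       struct ggml_opt_params opt_params = ggml_opt_default_params(GGML_OPT_ADAM);
+//
+//       enum ggml_opt_result res = ggml_opt(ctx, opt_params, f);
+//
+//       // res is GGML_OPT_OK once one of the convergence tests is satisfied
+//
+//       ...
+//   }
+//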
+// +// For example, here we define the function: f(x) = a*x^2 + b +// +// { +// struct ggml_init_params params = { +// .mem_size = 16*1024*1024, +// .mem_buffer = NULL, +// }; +// +// // memory allocation happens here +// struct ggml_context * ctx = ggml_init(params); +// +// struct ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1); +// +// ggml_set_param(ctx, x); // x is an input variable +// +// struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1); +// struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1); +// struct ggml_tensor * x2 = ggml_mul(ctx, x, x); +// struct ggml_tensor * f = ggml_add(ctx, ggml_mul(ctx, a, x2), b); +// +// ... +// } +// +// Notice that the function definition above does not involve any actual computation. The computation is performed only +// when the user explicitly requests it. For example, to compute the function's value at x = 2.0: +// +// { +// ... +// +// struct ggml_cgraph gf = ggml_build_forward(f); +// +// // set the input variable and parameter values +// ggml_set_f32(x, 2.0f); +// ggml_set_f32(a, 3.0f); +// ggml_set_f32(b, 4.0f); +// +// ggml_graph_compute(ctx0, &gf); +// +// printf("f = %f\n", ggml_get_f32_1d(f, 0)); +// +// ... +// } +// +// The actual computation is performed in the ggml_graph_compute() function. +// +// The ggml_new_tensor_...() functions create new tensors. They are allocated in the memory buffer provided to the +// ggml_init() function. You have to be careful not to exceed the memory buffer size. Therefore, you have to know +// in advance how much memory you need for your computation. Alternatively, you can allocate a large enough memory +// and after defining the computation graph, call the ggml_used_mem() function to find out how much memory was +// actually needed. +// +// The ggml_set_param() function marks a tensor as an input variable. This is used by the automatic +// differentiation and optimization algorithms. +// +// The described approach allows to define the function graph once and then compute its forward or backward graphs +// multiple times. All computations will use the same memory buffer allocated in the ggml_init() function. This way +// the user can avoid the memory allocation overhead at runtime. +// +// The library supports multi-dimensional tensors - up to 4 dimensions. The FP16 and FP32 data types are first class +// citizens, but in theory the library can be extended to support FP8 and integer data types. +// +// Each tensor operation produces a new tensor. Initially the library was envisioned to support only the use of unary +// and binary operations. Most of the available operations fall into one of these two categories. With time, it became +// clear that the library needs to support more complex operations. The way to support these operations is not clear +// yet, but a few examples are demonstrated in the following operations: +// +// - ggml_permute() +// - ggml_conv_1d_1s() +// - ggml_conv_1d_2s() +// +// For each tensor operator, the library implements a forward and backward computation function. The forward function +// computes the output tensor value given the input tensor values. The backward function computes the adjoint of the +// input tensors given the adjoint of the output tensor. For a detailed explanation of what this means, take a +// calculus class, or watch the following video: +// +// What is Automatic Differentiation? 
+// https://www.youtube.com/watch?v=wG_nF1awSSY +// +// +// ## Tensor data (struct ggml_tensor) +// +// The tensors are stored in memory via the ggml_tensor struct. The structure provides information about the size of +// the tensor, the data type, and the memory buffer where the tensor data is stored. Additionally, it contains +// pointers to the "source" tensors - i.e. the tensors that were used to compute the current tensor. For example: +// +// { +// struct ggml_tensor * c = ggml_add(ctx, a, b); +// +// assert(c->src[0] == a); +// assert(c->src[1] == b); +// } +// +// The multi-dimensional tensors are stored in row-major order. The ggml_tensor struct contains fields for the +// number of elements in each dimension ("ne") as well as the number of bytes ("nb", a.k.a. stride). This allows +// to store tensors that are not contiguous in memory, which is useful for operations such as transposition and +// permutation. All tensor operations have to take the stride into account and not assume that the tensor is +// contiguous in memory. +// +// The data of the tensor is accessed via the "data" pointer. For example: +// +// { +// struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 2, 3); +// +// // a[1, 2] = 1.0f; +// *(float *) ((char *) a->data + 2*a->nb[1] + 1*a->nb[0]) = 1.0f; +// +// // a[2, 0] = 2.0f; +// *(float *) ((char *) a->data + 0*a->nb[1] + 2*a->nb[0]) = 2.0f; +// +// ... +// } +// +// Alternatively, there are helper functions, such as ggml_get_f32_1d() and ggml_set_f32_1d() that can be used. +// +// ## The matrix multiplication operator (ggml_mul_mat) +// +// TODO +// +// +// ## Multi-threading +// +// TODO +// +// +// ## Overview of ggml.c +// +// TODO +// +// +// ## SIMD optimizations +// +// TODO +// +// +// ## Debugging ggml +// +// TODO +// +// + +#ifdef __cplusplus +extern "C" { +#endif + +#include +#include +#include + +#define GGML_MAX_DIMS 4 +#define GGML_MAX_NODES 4096 +#define GGML_MAX_PARAMS 16 +#define GGML_MAX_CONTEXTS 64 +#define GGML_MAX_OPT 4 + +#ifdef __ARM_NEON +// we use the built-in 16-bit float type +typedef __fp16 ggml_fp16_t; +#else +typedef uint16_t ggml_fp16_t; +#endif + +// convert FP16 <-> FP32 +float ggml_fp16_to_fp32(ggml_fp16_t x); +ggml_fp16_t ggml_fp32_to_fp16(float x); + +struct ggml_object; +struct ggml_context; + +enum ggml_type { + GGML_TYPE_I8, + GGML_TYPE_I16, + GGML_TYPE_I32, + GGML_TYPE_F16, + GGML_TYPE_F32, + GGML_TYPE_COUNT, +}; + +// available tensor operations: +enum ggml_op { + GGML_OP_NONE = 0, + + GGML_OP_DUP, + GGML_OP_ADD, + GGML_OP_SUB, + GGML_OP_MUL, + GGML_OP_DIV, + GGML_OP_SQR, + GGML_OP_SQRT, + GGML_OP_SUM, + GGML_OP_MEAN, + GGML_OP_REPEAT, + GGML_OP_ABS, + GGML_OP_SGN, + GGML_OP_NEG, + GGML_OP_STEP, + GGML_OP_RELU, + GGML_OP_GELU, + GGML_OP_NORM, // normalize + + GGML_OP_MUL_MAT, + + GGML_OP_SCALE, + GGML_OP_CPY, + GGML_OP_RESHAPE, + GGML_OP_VIEW, + GGML_OP_PERMUTE, + GGML_OP_TRANSPOSE, + GGML_OP_GET_ROWS, + GGML_OP_DIAG_MASK_INF, + GGML_OP_SOFT_MAX, + GGML_OP_ROPE, + GGML_OP_CONV_1D_1S, + GGML_OP_CONV_1D_2S, + + GGML_OP_FLASH_ATTN, + GGML_OP_FLASH_FF, + + GGML_OP_COUNT, +}; + +// n-dimensional tensor +struct ggml_tensor { + enum ggml_type type; + + int n_dims; + int ne[GGML_MAX_DIMS]; // number of elements + size_t nb[GGML_MAX_DIMS]; // stride in bytes: + // nb[0] = sizeof(type) + // nb[1] = nb[0] * ne[0] + padding + // nb[i] = nb[i-1] * ne[i-1] + + // compute data + enum ggml_op op; + + bool is_param; + + struct ggml_tensor * grad; + struct ggml_tensor * src0; + struct ggml_tensor * src1; + struct ggml_tensor 
* opt[GGML_MAX_OPT]; + + // thread scheduling + int n_tasks; + + // performance + int perf_runs; + int64_t perf_cycles; + int64_t perf_time_us; + + void * data; + char padding[8]; +}; + +// computation graph +struct ggml_cgraph { + int n_nodes; + int n_leafs; + int n_threads; + + size_t work_size; + struct ggml_tensor * work; + + struct ggml_tensor * nodes[GGML_MAX_NODES]; + struct ggml_tensor * grads[GGML_MAX_NODES]; + struct ggml_tensor * leafs[GGML_MAX_NODES]; + + // performance + int perf_runs; + int64_t perf_cycles; + int64_t perf_time_us; +}; + +// scratch buffer +struct ggml_scratch { + size_t offs; + size_t size; + void * data; +}; + +struct ggml_init_params { + // memory pool + size_t mem_size; // bytes + void * mem_buffer; // if NULL, memory will be allocated internally +}; + +void ggml_time_init(void); // call this once at the beginning of the program +int64_t ggml_time_ms(void); +int64_t ggml_time_us(void); +int64_t ggml_cycles(void); +int64_t ggml_cycles_per_ms(void); + +void ggml_print_object (const struct ggml_object * obj); +void ggml_print_objects(const struct ggml_context * ctx); + +int ggml_nelements(const struct ggml_tensor * tensor); +size_t ggml_nbytes (const struct ggml_tensor * tensor); + +size_t ggml_type_size (enum ggml_type type); +size_t ggml_element_size(const struct ggml_tensor * tensor); + +struct ggml_context * ggml_init(struct ggml_init_params params); +void ggml_free(struct ggml_context * ctx); + +size_t ggml_used_mem(const struct ggml_context * ctx); + +size_t ggml_set_scratch(struct ggml_context * ctx, struct ggml_scratch scratch); + +struct ggml_tensor * ggml_new_tensor( + struct ggml_context * ctx, + enum ggml_type type, + int n_dims, + const int *ne); + +struct ggml_tensor * ggml_new_tensor_1d( + struct ggml_context * ctx, + enum ggml_type type, + int ne0); + +struct ggml_tensor * ggml_new_tensor_2d( + struct ggml_context * ctx, + enum ggml_type type, + int ne0, + int ne1); + +struct ggml_tensor * ggml_new_tensor_3d( + struct ggml_context * ctx, + enum ggml_type type, + int ne0, + int ne1, + int ne2); + +struct ggml_tensor * ggml_new_tensor_4d( + struct ggml_context * ctx, + enum ggml_type type, + int ne0, + int ne1, + int ne2, + int ne3); + +struct ggml_tensor * ggml_new_i32(struct ggml_context * ctx, int32_t value); +struct ggml_tensor * ggml_new_f32(struct ggml_context * ctx, float value); + +struct ggml_tensor * ggml_dup_tensor (struct ggml_context * ctx, const struct ggml_tensor * src); +struct ggml_tensor * ggml_view_tensor(struct ggml_context * ctx, const struct ggml_tensor * src); + +struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor); +struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value); +struct ggml_tensor * ggml_set_f32 (struct ggml_tensor * tensor, float value); + +int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i); +void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value); + +float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i); +void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value); + + void * ggml_get_data (const struct ggml_tensor * tensor); +float * ggml_get_data_f32(const struct ggml_tensor * tensor); + +// +// operations on tensors with backpropagation +// + +struct ggml_tensor * ggml_dup( + struct ggml_context * ctx, + struct ggml_tensor * a); + +struct ggml_tensor * ggml_add( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + +struct ggml_tensor * ggml_sub( + struct ggml_context * ctx, + 
struct ggml_tensor * a, + struct ggml_tensor * b); + +struct ggml_tensor * ggml_mul( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + +struct ggml_tensor * ggml_div( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + +struct ggml_tensor * ggml_sqr( + struct ggml_context * ctx, + struct ggml_tensor * a); + +struct ggml_tensor * ggml_sqrt( + struct ggml_context * ctx, + struct ggml_tensor * a); + +// return scalar +// TODO: compute sum along rows +struct ggml_tensor * ggml_sum( + struct ggml_context * ctx, + struct ggml_tensor * a); + +// mean along rows +struct ggml_tensor * ggml_mean( + struct ggml_context * ctx, + struct ggml_tensor * a); + +// if a is the same shape as b, and a is not parameter, return a +// otherwise, return a new tensor: repeat(a) to fit in b +struct ggml_tensor * ggml_repeat( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + +struct ggml_tensor * ggml_abs( + struct ggml_context * ctx, + struct ggml_tensor * a); + +struct ggml_tensor * ggml_sgn( + struct ggml_context * ctx, + struct ggml_tensor * a); + +struct ggml_tensor * ggml_neg( + struct ggml_context * ctx, + struct ggml_tensor * a); + +struct ggml_tensor * ggml_step( + struct ggml_context * ctx, + struct ggml_tensor * a); + +struct ggml_tensor * ggml_relu( + struct ggml_context * ctx, + struct ggml_tensor * a); + +// TODO: double-check this computation is correct +struct ggml_tensor * ggml_gelu( + struct ggml_context * ctx, + struct ggml_tensor * a); + +// normalize along rows +// TODO: eps is hardcoded to 1e-5 for now +struct ggml_tensor * ggml_norm( + struct ggml_context * ctx, + struct ggml_tensor * a); + +// A: m rows, n columns +// B: p rows, n columns (i.e. we transpose it internally) +// result is m columns, p rows +struct ggml_tensor * ggml_mul_mat( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + +// +// operations on tensors without backpropagation +// + +// in-place, returns view(a) +struct ggml_tensor * ggml_scale( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + +// a -> b, return view(b) +struct ggml_tensor * ggml_cpy( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + +// return view(a), b specifies the new shape +// TODO: when we start computing gradient, make a copy instead of view +struct ggml_tensor * ggml_reshape( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + +// return view(a) +// TODO: when we start computing gradient, make a copy instead of view +struct ggml_tensor * ggml_reshape_2d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int ne0, + int ne1); + +// return view(a) +// TODO: when we start computing gradient, make a copy instead of view +struct ggml_tensor * ggml_reshape_3d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int ne0, + int ne1, + int ne2); + +// offset in bytes +struct ggml_tensor * ggml_view_1d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int ne0, + size_t offset); + +struct ggml_tensor * ggml_view_2d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int ne0, + int ne1, + size_t nb1, // row stride in bytes + size_t offset); + +struct ggml_tensor * ggml_permute( + struct ggml_context * ctx, + struct ggml_tensor * a, + int axis0, + int axis1, + int axis2, + int axis3); + +// alias for ggml_permute(ctx, a, 1, 0, 2, 3) +struct ggml_tensor * ggml_transpose( + struct ggml_context * ctx, + struct 
ggml_tensor * a); + +struct ggml_tensor * ggml_get_rows( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + +// set elements above the diagonal to -INF +// in-place, returns view(a) +struct ggml_tensor * ggml_diag_mask_inf( + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_past); + +// in-place, returns view(a) +struct ggml_tensor * ggml_soft_max( + struct ggml_context * ctx, + struct ggml_tensor * a); + +// rotary position embedding +// in-place, returns view(a) +// if mode == 1, skip n_past elements +// TODO: avoid creating a new tensor every time +struct ggml_tensor * ggml_rope( + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_past, + int n_dims, + int mode); + +// padding = 1 +// TODO: we don't support extra parameters for now +// that's why we are hard-coding the stride, padding, and dilation +// not great .. +struct ggml_tensor * ggml_conv_1d_1s( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + +struct ggml_tensor * ggml_conv_1d_2s( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + +struct ggml_tensor * ggml_flash_attn( + struct ggml_context * ctx, + struct ggml_tensor * q, + struct ggml_tensor * k, + struct ggml_tensor * v, + bool masked); + +struct ggml_tensor * ggml_flash_ff( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b0, + struct ggml_tensor * b1, + struct ggml_tensor * c0, + struct ggml_tensor * c1); + +// +// automatic differentiation +// + +void ggml_set_param( + struct ggml_context * ctx, + struct ggml_tensor * tensor); + +void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor); + +struct ggml_cgraph ggml_build_forward (struct ggml_tensor * tensor); +struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cgraph * gf, bool keep); + +void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph); +void ggml_graph_reset (struct ggml_cgraph * cgraph); + +// print info and performance information for the graph +void ggml_graph_print(const struct ggml_cgraph * cgraph); + +// dump the graph into a file using the dot format +void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * gf, const char * filename); + +// +// optimization +// + +// optimization methods +enum ggml_opt_type { + GGML_OPT_ADAM, + GGML_OPT_LBFGS, +}; + +// linesearch methods +enum ggml_linesearch { + GGML_LINESEARCH_DEFAULT = 1, + + GGML_LINESEARCH_BACKTRACKING_ARMIJO = 0, + GGML_LINESEARCH_BACKTRACKING_WOLFE = 1, + GGML_LINESEARCH_BACKTRACKING_STRONG_WOLFE = 2, +}; + +// optimization return values +enum ggml_opt_result { + GGML_OPT_OK = 0, + GGML_OPT_DID_NOT_CONVERGE, + GGML_OPT_NO_CONTEXT, + GGML_OPT_INVALID_WOLFE, + GGML_OPT_FAIL, + + GGML_LINESEARCH_FAIL = -128, + GGML_LINESEARCH_MINIMUM_STEP, + GGML_LINESEARCH_MAXIMUM_STEP, + GGML_LINESEARCH_MAXIMUM_ITERATIONS, + GGML_LINESEARCH_INVALID_PARAMETERS, +}; + +// optimization parameters +// +// see ggml.c (ggml_opt_default_params) for default values +// +struct ggml_opt_params { + enum ggml_opt_type type; + + int n_threads; + + // delta-based convergence test + // + // if past == 0 - disabled + // if past > 0: + // stop if |f(x) - f(x_past)| < delta * max(1, |f(x)|) + // + int past; + float delta; + + // maximum number of iterations without improvement + // + // if 0 - disabled + // if > 0: + // assume convergence if no cost improvement in this number of iterations + // + int max_no_improvement; + + bool 
print_forward_graph; + bool print_backward_graph; + + // ADAM parameters + struct { + int n_iter; + + float alpha; // learning rate + float beta1; + float beta2; + float eps; // epsilon for numerical stability + float eps_f; // epsilon for convergence test + float eps_g; // epsilon for convergence test + } adam; + + // LBFGS parameters + struct { + int m; // number of corrections to approximate the inv. Hessian + int n_iter; + int max_linesearch; + + float eps; // convergence tolerance + float ftol; // line search tolerance + float wolfe; + float min_step; + float max_step; + + enum ggml_linesearch linesearch; + } lbfgs; +}; + +struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type); + +// optimize the function defined by the tensor f +enum ggml_opt_result ggml_opt( + struct ggml_context * ctx, + struct ggml_opt_params params, + struct ggml_tensor * f); + +// +// system info +// + +int ggml_cpu_has_avx(void); +int ggml_cpu_has_avx2(void); +int ggml_cpu_has_avx512(void); +int ggml_cpu_has_fma(void); +int ggml_cpu_has_neon(void); +int ggml_cpu_has_arm_fma(void); +int ggml_cpu_has_f16c(void); +int ggml_cpu_has_fp16_va(void); +int ggml_cpu_has_wasm_simd(void); +int ggml_cpu_has_blas(void); +int ggml_cpu_has_sse3(void); +int ggml_cpu_has_vsx(void); + +#ifdef __cplusplus +} +#endif diff --git a/bindings/ruby/ext/ruby_whisper.c b/bindings/ruby/ext/ruby_whisper.c new file mode 100644 index 00000000000..4abfab524b4 --- /dev/null +++ b/bindings/ruby/ext/ruby_whisper.c @@ -0,0 +1,92 @@ +#include +#include "ruby_whisper.h" + +VALUE mWhisper; +VALUE cContext; +VALUE cParams; + +static void ruby_whisper_free(ruby_whisper *rw) { + if (rw->context) { + whisper_free(rw->context); + rw->context = NULL; + } +} +static void ruby_whisper_params_free(ruby_whisper_params *rwp) { + if (rwp->params) { + free(rwp->params); + rwp->params = NULL; + } +} + +void rb_whisper_mark(ruby_whisper *rw) { + // call rb_gc_mark on any ruby references in rw +} + +void rb_whisper_free(ruby_whisper *rw) { + ruby_whisper_free(rw); + free(rw); +} + +void rb_whisper_params_mark(ruby_whisper_params *rwp) { +} + +void rb_whisper_params_free(ruby_whisper_params *rwp) { + ruby_whisper_params_free(rwp); + free(rwp); +} + +static VALUE ruby_whisper_allocate(VALUE klass) { + ruby_whisper *rw; + rw = ALLOC(ruby_whisper); + rw->context = NULL; + return Data_Wrap_Struct(klass, rb_whisper_mark, rb_whisper_free, rw); +} + +static VALUE ruby_whisper_params_allocate(VALUE klass) { + ruby_whisper_params *rwp; + rwp = ALLOC(ruby_whisper_params); + rwp->params = ALLOC(struct whisper_full_params); + return Data_Wrap_Struct(klass, rb_whisper_params_mark, rb_whisper_params_free, rwp); +} + +static VALUE ruby_whisper_initialize(int argc, VALUE *argv, VALUE self) { + ruby_whisper *rw; + VALUE whisper_model_file_path; + + // TODO: we can support init from buffer here too maybe another ruby object to expose + rb_scan_args(argc, argv, "01", &whisper_model_file_path); + Data_Get_Struct(self, ruby_whisper, rw); + + if (!rb_respond_to(whisper_model_file_path, rb_intern("to_s"))) { + rb_raise(rb_eRuntimeError, "Expected file path to model to initialize Whisper::Context"); + } + rw->context = whisper_init_from_file(StringValueCStr(whisper_model_file_path)); + return self; +} + +/* + * params.auto_detection = true|false + */ +static VALUE ruby_whisper_params_set_auto_detection(VALUE self, VALUE value) { + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper_params, rwp); + if (value == Qfalse || value == Qnil) { + rwp->params->language = 
"auto"; + } else { + rwp->params->language = NULL; + } + return value; +} + +void Init_whisper() { + mWhisper = rb_define_module("Whisper"); + cContext = rb_define_class_under(mWhisper, "Context", rb_cObject); + cParams = rb_define_class_under(mWhisper, "Params", rb_cObject); + + rb_define_alloc_func(cContext, ruby_whisper_allocate); + rb_define_method(cContext, "initialize", ruby_whisper_initialize, -1); + + rb_define_alloc_func(cParams, ruby_whisper_params_allocate); + + rb_define_method(cParams, "auto_detection=", ruby_whisper_params_set_auto_detection, 1); +} diff --git a/bindings/ruby/ext/ruby_whisper.h b/bindings/ruby/ext/ruby_whisper.h new file mode 100644 index 00000000000..246d13334a1 --- /dev/null +++ b/bindings/ruby/ext/ruby_whisper.h @@ -0,0 +1,14 @@ +#ifndef __RUBY_WHISPER_H +#define __RUBY_WHISPER_H + +#include "whisper.h" + +typedef struct { + struct whisper_context *context; +} ruby_whisper; + +typedef struct { + struct whisper_full_params *params; +} ruby_whisper_params; + +#endif diff --git a/bindings/ruby/ext/whisper.cpp b/bindings/ruby/ext/whisper.cpp new file mode 100644 index 00000000000..24e16bd5cf9 --- /dev/null +++ b/bindings/ruby/ext/whisper.cpp @@ -0,0 +1,4814 @@ +#define WHISPER_BUILD +#include "whisper.h" + +#include "ggml.h" + +#include +#include +#define _USE_MATH_DEFINES +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#if defined(GGML_BIG_ENDIAN) +#include + +template +static T byteswap(T value) { + return std::byteswap(value); +} + +template<> +float byteswap(float value) { + return std::bit_cast(byteswap(std::bit_cast(value))); +} + +template +static void byteswap_tensor_data(ggml_tensor * tensor) { + T * datum = reinterpret_cast(tensor->data); + for (int i = 0; i < ggml_nelements(tensor); i++) { + datum[i] = byteswap(datum[i]); + } +} + +static void byteswap_tensor(ggml_tensor * tensor) { + switch (tensor->type) { + case GGML_TYPE_I16: { + byteswap_tensor_data(tensor); + break; + } + case GGML_TYPE_F16: { + byteswap_tensor_data(tensor); + break; + } + case GGML_TYPE_I32: { + byteswap_tensor_data(tensor); + break; + } + case GGML_TYPE_F32: { + byteswap_tensor_data(tensor); + break; + } + default: { // GML_TYPE_I8 + break; + } + } +} + +#define BYTESWAP_VALUE(d) d = byteswap(d) +#define BYTESWAP_FILTERS(f) \ + do { \ + for (auto & datum : f.data) { \ + datum = byteswap(datum); \ + } \ + } while (0) +#define BYTESWAP_TENSOR(t) \ + do { \ + byteswap_tensor(tensor); \ + } while (0) +#else +#define BYTESWAP_VALUE(d) do {} while (0) +#define BYTESWAP_FILTERS(f) do {} while (0) +#define BYTESWAP_TENSOR(t) do {} while (0) +#endif + +#define WHISPER_ASSERT(x) \ + do { \ + if (!(x)) { \ + fprintf(stderr, "WHISPER_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \ + abort(); \ + } \ + } while (0) + +// define this to enable verbose trace logging - useful for debugging purposes +//#define WHISPER_DEBUG + +#if defined(WHISPER_DEBUG) +#define WHISPER_PRINT_DEBUG(...) \ + do { \ + fprintf(stderr, __VA_ARGS__); \ + } while (0) +#else +#define WHISPER_PRINT_DEBUG(...) 
+#endif + +#define WHISPER_USE_FLASH_ATTN +//#define WHISPER_USE_FLASH_FF +#define WHISPER_MAX_DECODERS 16 + +#define WHISPER_USE_SCRATCH +#define WHISPER_MAX_SCRATCH_BUFFERS 16 + +// available whisper models +enum e_model { + MODEL_UNKNOWN, + MODEL_TINY, + MODEL_BASE, + MODEL_SMALL, + MODEL_MEDIUM, + MODEL_LARGE, +}; + +static const std::map> g_lang = { + { "en", { 0, "english", } }, + { "zh", { 1, "chinese", } }, + { "de", { 2, "german", } }, + { "es", { 3, "spanish", } }, + { "ru", { 4, "russian", } }, + { "ko", { 5, "korean", } }, + { "fr", { 6, "french", } }, + { "ja", { 7, "japanese", } }, + { "pt", { 8, "portuguese", } }, + { "tr", { 9, "turkish", } }, + { "pl", { 10, "polish", } }, + { "ca", { 11, "catalan", } }, + { "nl", { 12, "dutch", } }, + { "ar", { 13, "arabic", } }, + { "sv", { 14, "swedish", } }, + { "it", { 15, "italian", } }, + { "id", { 16, "indonesian", } }, + { "hi", { 17, "hindi", } }, + { "fi", { 18, "finnish", } }, + { "vi", { 19, "vietnamese", } }, + { "iw", { 20, "hebrew", } }, + { "uk", { 21, "ukrainian", } }, + { "el", { 22, "greek", } }, + { "ms", { 23, "malay", } }, + { "cs", { 24, "czech", } }, + { "ro", { 25, "romanian", } }, + { "da", { 26, "danish", } }, + { "hu", { 27, "hungarian", } }, + { "ta", { 28, "tamil", } }, + { "no", { 29, "norwegian", } }, + { "th", { 30, "thai", } }, + { "ur", { 31, "urdu", } }, + { "hr", { 32, "croatian", } }, + { "bg", { 33, "bulgarian", } }, + { "lt", { 34, "lithuanian", } }, + { "la", { 35, "latin", } }, + { "mi", { 36, "maori", } }, + { "ml", { 37, "malayalam", } }, + { "cy", { 38, "welsh", } }, + { "sk", { 39, "slovak", } }, + { "te", { 40, "telugu", } }, + { "fa", { 41, "persian", } }, + { "lv", { 42, "latvian", } }, + { "bn", { 43, "bengali", } }, + { "sr", { 44, "serbian", } }, + { "az", { 45, "azerbaijani", } }, + { "sl", { 46, "slovenian", } }, + { "kn", { 47, "kannada", } }, + { "et", { 48, "estonian", } }, + { "mk", { 49, "macedonian", } }, + { "br", { 50, "breton", } }, + { "eu", { 51, "basque", } }, + { "is", { 52, "icelandic", } }, + { "hy", { 53, "armenian", } }, + { "ne", { 54, "nepali", } }, + { "mn", { 55, "mongolian", } }, + { "bs", { 56, "bosnian", } }, + { "kk", { 57, "kazakh", } }, + { "sq", { 58, "albanian", } }, + { "sw", { 59, "swahili", } }, + { "gl", { 60, "galician", } }, + { "mr", { 61, "marathi", } }, + { "pa", { 62, "punjabi", } }, + { "si", { 63, "sinhala", } }, + { "km", { 64, "khmer", } }, + { "sn", { 65, "shona", } }, + { "yo", { 66, "yoruba", } }, + { "so", { 67, "somali", } }, + { "af", { 68, "afrikaans", } }, + { "oc", { 69, "occitan", } }, + { "ka", { 70, "georgian", } }, + { "be", { 71, "belarusian", } }, + { "tg", { 72, "tajik", } }, + { "sd", { 73, "sindhi", } }, + { "gu", { 74, "gujarati", } }, + { "am", { 75, "amharic", } }, + { "yi", { 76, "yiddish", } }, + { "lo", { 77, "lao", } }, + { "uz", { 78, "uzbek", } }, + { "fo", { 79, "faroese", } }, + { "ht", { 80, "haitian creole", } }, + { "ps", { 81, "pashto", } }, + { "tk", { 82, "turkmen", } }, + { "nn", { 83, "nynorsk", } }, + { "mt", { 84, "maltese", } }, + { "sa", { 85, "sanskrit", } }, + { "lb", { 86, "luxembourgish", } }, + { "my", { 87, "myanmar", } }, + { "bo", { 88, "tibetan", } }, + { "tl", { 89, "tagalog", } }, + { "mg", { 90, "malagasy", } }, + { "as", { 91, "assamese", } }, + { "tt", { 92, "tatar", } }, + { "haw", { 93, "hawaiian", } }, + { "ln", { 94, "lingala", } }, + { "ha", { 95, "hausa", } }, + { "ba", { 96, "bashkir", } }, + { "jw", { 97, "javanese", } }, + { "su", { 98, "sundanese", } }, +}; + +static const size_t 
MB = 1024*1024; + +static const std::map MEM_REQ_SCRATCH0 = { + { MODEL_TINY, 12ull*MB }, + { MODEL_BASE, 15ull*MB }, + { MODEL_SMALL, 23ull*MB }, + { MODEL_MEDIUM, 31ull*MB }, + { MODEL_LARGE, 38ull*MB }, +}; + +static const std::map MEM_REQ_SCRATCH1 = { + { MODEL_TINY, 18ull*MB }, + { MODEL_BASE, 24ull*MB }, + { MODEL_SMALL, 36ull*MB }, + { MODEL_MEDIUM, 48ull*MB }, + { MODEL_LARGE, 60ull*MB }, +}; + +static const std::map MEM_REQ_SCRATCH2 = { + { MODEL_TINY, 4ull*MB }, + { MODEL_BASE, 4ull*MB }, + { MODEL_SMALL, 6ull*MB }, + { MODEL_MEDIUM, 7ull*MB }, + { MODEL_LARGE, 9ull*MB }, +}; + +static const std::map MEM_REQ_SCRATCH3 = { + { MODEL_TINY, 4ull*MB }, + { MODEL_BASE, 4ull*MB }, + { MODEL_SMALL, 6ull*MB }, + { MODEL_MEDIUM, 7ull*MB }, + { MODEL_LARGE, 9ull*MB }, +}; + +static const std::map MEM_REQ_MODEL = { + { MODEL_TINY, 74ull*MB }, + { MODEL_BASE, 142ull*MB }, + { MODEL_SMALL, 466ull*MB }, + { MODEL_MEDIUM, 1464ull*MB }, + { MODEL_LARGE, 2952ull*MB }, +}; + +static const std::map MEM_REQ_KV_SELF = { + { MODEL_TINY, 3ull*MB }, + { MODEL_BASE, 6ull*MB }, + { MODEL_SMALL, 16ull*MB }, + { MODEL_MEDIUM, 43ull*MB }, + { MODEL_LARGE, 71ull*MB }, +}; + +static const std::map MEM_REQ_KV_CROSS = { + { MODEL_TINY, 9ull*MB }, + { MODEL_BASE, 18ull*MB }, + { MODEL_SMALL, 53ull*MB }, + { MODEL_MEDIUM, 141ull*MB }, + { MODEL_LARGE, 235ull*MB }, +}; + +static const std::map MEM_REQ_ENCODE = { + { MODEL_TINY, 6ull*MB }, + { MODEL_BASE, 8ull*MB }, + { MODEL_SMALL, 13ull*MB }, + { MODEL_MEDIUM, 22ull*MB }, + { MODEL_LARGE, 33ull*MB }, +}; + +static const std::map MEM_REQ_DECODE = { + { MODEL_TINY, 3ull*MB }, + { MODEL_BASE, 5ull*MB }, + { MODEL_SMALL, 10ull*MB }, + { MODEL_MEDIUM, 18ull*MB }, + { MODEL_LARGE, 27ull*MB }, +}; + +struct whisper_mel { + int n_len; + int n_mel; + + std::vector data; +}; + +struct whisper_filters { + int32_t n_mel; + int32_t n_fft; + + std::vector data; +}; + +struct whisper_vocab { + using id = int32_t; + using token = std::string; + + int n_vocab = 51864; + + std::map token_to_id; + std::map id_to_token; + + id token_eot = 50256; + id token_sot = 50257; + id token_prev = 50360; + id token_solm = 50361; // ?? 
+ id token_not = 50362; // no timestamps + id token_beg = 50363; + + // available tasks + static const id token_translate = 50358; + static const id token_transcribe = 50359; + + bool is_multilingual() const { + return n_vocab == 51865; + } +}; + +struct whisper_segment { + int64_t t0; + int64_t t1; + + std::string text; + + std::vector tokens; +}; + +// medium +// hparams: { +// 'n_mels': 80, +// 'n_vocab': 51864, +// 'n_audio_ctx': 1500, +// 'n_audio_state': 1024, +// 'n_audio_head': 16, +// 'n_audio_layer': 24, +// 'n_text_ctx': 448, +// 'n_text_state': 1024, +// 'n_text_head': 16, +// 'n_text_layer': 24 +// } +// +// default hparams (Whisper tiny) +struct whisper_hparams { + int32_t n_vocab = 51864; + int32_t n_audio_ctx = 1500; + int32_t n_audio_state = 384; + int32_t n_audio_head = 6; + int32_t n_audio_layer = 4; + int32_t n_text_ctx = 448; + int32_t n_text_state = 384; + int32_t n_text_head = 6; + int32_t n_text_layer = 4; + int32_t n_mels = 80; + int32_t f16 = 1; +}; + +// audio encoding layer +struct whisper_layer_encoder { + // encoder.blocks.*.attn_ln + struct ggml_tensor * attn_ln_0_w; + struct ggml_tensor * attn_ln_0_b; + + // encoder.blocks.*.attn.out + struct ggml_tensor * attn_ln_1_w; + struct ggml_tensor * attn_ln_1_b; + + // encoder.blocks.*.attn.query + struct ggml_tensor * attn_q_w; + struct ggml_tensor * attn_q_b; + + // encoder.blocks.*.attn.key + struct ggml_tensor * attn_k_w; + + // encoder.blocks.*.attn.value + struct ggml_tensor * attn_v_w; + struct ggml_tensor * attn_v_b; + + // encoder.blocks.*.mlp_ln + struct ggml_tensor * mlp_ln_w; + struct ggml_tensor * mlp_ln_b; + + // encoder.blocks.*.mlp.0 + struct ggml_tensor * mlp_0_w; + struct ggml_tensor * mlp_0_b; + + // encoder.blocks.*.mlp.2 + struct ggml_tensor * mlp_1_w; + struct ggml_tensor * mlp_1_b; +}; + +// token decoding layer +struct whisper_layer_decoder { + // decoder.blocks.*.attn_ln + struct ggml_tensor * attn_ln_0_w; + struct ggml_tensor * attn_ln_0_b; + + // decoder.blocks.*.attn.out + struct ggml_tensor * attn_ln_1_w; + struct ggml_tensor * attn_ln_1_b; + + // decoder.blocks.*.attn.query + struct ggml_tensor * attn_q_w; + struct ggml_tensor * attn_q_b; + + // decoder.blocks.*.attn.key + struct ggml_tensor * attn_k_w; + + // decoder.blocks.*.attn.value + struct ggml_tensor * attn_v_w; + struct ggml_tensor * attn_v_b; + + // decoder.blocks.*.cross_attn_ln + struct ggml_tensor * cross_attn_ln_0_w; + struct ggml_tensor * cross_attn_ln_0_b; + + // decoder.blocks.*.cross_attn.out + struct ggml_tensor * cross_attn_ln_1_w; + struct ggml_tensor * cross_attn_ln_1_b; + + // decoder.blocks.*.cross_attn.query + struct ggml_tensor * cross_attn_q_w; + struct ggml_tensor * cross_attn_q_b; + + // decoder.blocks.*.cross_attn.key + struct ggml_tensor * cross_attn_k_w; + + // decoder.blocks.*.cross_attn.value + struct ggml_tensor * cross_attn_v_w; + struct ggml_tensor * cross_attn_v_b; + + // decoder.blocks.*.mlp_ln + struct ggml_tensor * mlp_ln_w; + struct ggml_tensor * mlp_ln_b; + + // decoder.blocks.*.mlp.0 + struct ggml_tensor * mlp_0_w; + struct ggml_tensor * mlp_0_b; + + // decoder.blocks.*.mlp.2 + struct ggml_tensor * mlp_1_w; + struct ggml_tensor * mlp_1_b; +}; + +struct whisper_kv_cache { + struct ggml_tensor * k; + struct ggml_tensor * v; + + struct ggml_context * ctx; + + std::vector buf; + + int n; // number of tokens currently in the cache +}; + +struct whisper_model { + e_model type = MODEL_UNKNOWN; + + whisper_hparams hparams; + whisper_filters filters; + + // encoder.positional_embedding + struct 
ggml_tensor * e_pe; + + // encoder.conv1 + struct ggml_tensor * e_conv_1_w; + struct ggml_tensor * e_conv_1_b; + + // encoder.conv2 + struct ggml_tensor * e_conv_2_w; + struct ggml_tensor * e_conv_2_b; + + // encoder.ln_post + struct ggml_tensor * e_ln_w; + struct ggml_tensor * e_ln_b; + + // decoder.positional_embedding + struct ggml_tensor * d_pe; + + // decoder.token_embedding + struct ggml_tensor * d_te; + + // decoder.ln + struct ggml_tensor * d_ln_w; + struct ggml_tensor * d_ln_b; + + std::vector layers_encoder; + std::vector layers_decoder; + + // context + struct ggml_context * ctx; + + // the model memory buffer is read-only and can be shared between processors + std::vector * buf; + + // tensors + int n_loaded; + std::map tensors; +}; + +struct whisper_sequence { + std::vector tokens; + + // the accumulated transcription in the current iteration (used to truncate the tokens array) + int result_len; + + double sum_logprobs_all; // the sum of the log probabilities of the tokens + double sum_logprobs; // the sum of the log probabilities of the tokens (first result_len tokens) + double avg_logprobs; // the average log probability of the tokens + double entropy; // the entropy of the tokens + double score; // likelihood rank score +}; + +// TAGS: WHISPER_DECODER_INIT +struct whisper_decoder { + // each decoder keeps its own KV-cache + whisper_kv_cache kv_self; + + // the currently generated sequence of tokens + whisper_sequence sequence; + + int seek_delta; // the window shift found so far based on the decoded timestamp tokens + + bool failed; // has the current segment failed to decode? + bool completed; // has the decoder completed the current segment? + bool has_ts; // have we already sampled a non-beg timestamp token for the current segment? 
+ + // new token probs, logits and logprobs after the last whisper_decode (1-dimensional array: [n_vocab]) + std::vector probs; + std::vector logits; + std::vector logprobs; + + std::vector tokens_tmp; // used for whisper_decode calls +}; + +struct whisper_context { + int64_t t_load_us = 0; + int64_t t_mel_us = 0; + int64_t t_sample_us = 0; + int64_t t_encode_us = 0; + int64_t t_decode_us = 0; + int64_t t_start_us = 0; + + int32_t n_sample = 0; // number of tokens sampled + int32_t n_encode = 0; // number of encoder calls + int32_t n_decode = 0; // number of decoder calls + int32_t n_fail_p = 0; // number of logprob threshold failures + int32_t n_fail_h = 0; // number of entropy threshold failures + + ggml_type wtype; // weight type (FP32 or FP16) + + whisper_mel mel; + + whisper_model model; + whisper_vocab vocab; + + // cross-attention KV cache for the decoders + // shared between all decoders + whisper_kv_cache kv_cross; + + whisper_decoder decoders[WHISPER_MAX_DECODERS] = {}; + + // memory buffers used by encode / decode contexts + std::vector buf_compute; + std::vector buf_scratch[WHISPER_MAX_SCRATCH_BUFFERS]; + + int buf_last = 0; + size_t buf_max_size[WHISPER_MAX_SCRATCH_BUFFERS] = { 0 }; + + // decode output (2-dimensional array: [n_tokens][n_vocab]) + std::vector logits; + + std::vector result_all; + std::vector prompt_past; + + // work container used to avoid memory allocations + std::vector> logits_id; + + mutable std::mt19937 rng; // used for sampling at t > 0.0 + + int lang_id; + + // [EXPERIMENTAL] token-level timestamps data + int64_t t_beg; + int64_t t_last; + whisper_token tid_last; + std::vector energy; // PCM signal energy + + // [EXPERIMENTAL] speed-up techniques + int32_t exp_n_audio_ctx; // 0 - use default + + void use_buf(struct ggml_context * ctx, int i) { +#if defined(WHISPER_USE_SCRATCH) + size_t last_size = 0; + + if (i == -1) { + last_size = ggml_set_scratch(ctx, { 0, 0, nullptr, }); + } else { + auto & buf = buf_scratch[i]; + last_size = ggml_set_scratch(ctx, { 0, buf.size(), buf.data(), }); + } + + if (buf_last >= 0) { + buf_max_size[buf_last] = std::max(buf_max_size[buf_last], last_size); + } + + buf_last = i; +#else + (void) i; + (void) ctx; +#endif + } + + size_t get_buf_max_mem(int i) const { +#if defined(WHISPER_USE_SCRATCH) + return buf_max_size[i]; +#else + (void) i; + return 0; +#endif + } +}; + +template +static void read_safe(whisper_model_loader * loader, T & dest) { + loader->read(loader->context, &dest, sizeof(T)); + BYTESWAP_VALUE(dest); +} + +static bool kv_cache_init( + const struct whisper_hparams & hparams, + const size_t mem_bytes, + struct whisper_kv_cache & cache, + ggml_type wtype, + int n_ctx) { + cache.buf.resize(mem_bytes); + + struct ggml_init_params params; + params.mem_size = cache.buf.size(); + params.mem_buffer = cache.buf.data(); + + cache.ctx = ggml_init(params); + + if (!cache.ctx) { + fprintf(stderr, "%s: failed to allocate memory for kv cache\n", __func__); + return false; + } + + const int n_text_state = hparams.n_text_state; + const int n_text_layer = hparams.n_text_layer; + + const int n_mem = n_text_layer*n_ctx; + const int n_elements = n_text_state*n_mem; + + cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements); + cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements); + + return true; +} + +static bool kv_cache_reinit(struct whisper_kv_cache & cache) { + WHISPER_ASSERT(cache.ctx); + + const int n_elements = ggml_nelements(cache.k); + WHISPER_ASSERT(n_elements == ggml_nelements(cache.v)); + + const ggml_type 
wtype = cache.k->type; + WHISPER_ASSERT(wtype == cache.v->type); + + WHISPER_ASSERT(cache.buf.size() >= 2*n_elements*ggml_type_size(wtype)); + + struct ggml_init_params params; + params.mem_size = cache.buf.size(); + params.mem_buffer = cache.buf.data(); + + cache.ctx = ggml_init(params); + + if (!cache.ctx) { + fprintf(stderr, "%s: failed to allocate memory for kv cache\n", __func__); + return false; + } + + cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements); + cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements); + + return true; +} + +static void kv_cache_free(struct whisper_kv_cache & cache) { + if (cache.ctx) { + ggml_free(cache.ctx); + cache.ctx = nullptr; + } +} + +// load the model from a ggml file +// +// file format: +// +// - hparams +// - pre-computed mel filters +// - vocab +// - weights +// +// see the convert-pt-to-ggml.py script for details +// +static bool whisper_model_load(struct whisper_model_loader * loader, whisper_context & wctx) { + fprintf(stderr, "%s: loading model\n", __func__); + + const int64_t t_start_us = ggml_time_us(); + + wctx.t_start_us = t_start_us; + + auto & model = wctx.model; + auto & vocab = wctx.vocab; + + // verify magic + { + uint32_t magic; + read_safe(loader, magic); + if (magic != 0x67676d6c) { + fprintf(stderr, "%s: invalid model data (bad magic)\n", __func__); + return false; + } + } + + //load hparams + { + auto & hparams = model.hparams; + + read_safe(loader, hparams.n_vocab); + read_safe(loader, hparams.n_audio_ctx); + read_safe(loader, hparams.n_audio_state); + read_safe(loader, hparams.n_audio_head); + read_safe(loader, hparams.n_audio_layer); + read_safe(loader, hparams.n_text_ctx); + read_safe(loader, hparams.n_text_state); + read_safe(loader, hparams.n_text_head); + read_safe(loader, hparams.n_text_layer); + read_safe(loader, hparams.n_mels); + read_safe(loader, hparams.f16); + + assert(hparams.n_text_state == hparams.n_audio_state); + + if (hparams.n_audio_layer == 4) { + model.type = e_model::MODEL_TINY; + } + + if (hparams.n_audio_layer == 6) { + model.type = e_model::MODEL_BASE; + } + + if (hparams.n_audio_layer == 12) { + model.type = e_model::MODEL_SMALL; + } + + if (hparams.n_audio_layer == 24) { + model.type = e_model::MODEL_MEDIUM; + } + + if (hparams.n_audio_layer == 32) { + model.type = e_model::MODEL_LARGE; + } + + // for the big tensors, we have the option to store the data in 16-bit floats + // in order to save memory and also to speed up the computation + wctx.wtype = model.hparams.f16 ? GGML_TYPE_F16 : GGML_TYPE_F32; + + const size_t scale = model.hparams.f16 ? 
1 : 2; + + fprintf(stderr, "%s: n_vocab = %d\n", __func__, hparams.n_vocab); + fprintf(stderr, "%s: n_audio_ctx = %d\n", __func__, hparams.n_audio_ctx); + fprintf(stderr, "%s: n_audio_state = %d\n", __func__, hparams.n_audio_state); + fprintf(stderr, "%s: n_audio_head = %d\n", __func__, hparams.n_audio_head); + fprintf(stderr, "%s: n_audio_layer = %d\n", __func__, hparams.n_audio_layer); + fprintf(stderr, "%s: n_text_ctx = %d\n", __func__, hparams.n_text_ctx); + fprintf(stderr, "%s: n_text_state = %d\n", __func__, hparams.n_text_state); + fprintf(stderr, "%s: n_text_head = %d\n", __func__, hparams.n_text_head); + fprintf(stderr, "%s: n_text_layer = %d\n", __func__, hparams.n_text_layer); + fprintf(stderr, "%s: n_mels = %d\n", __func__, hparams.n_mels); + fprintf(stderr, "%s: f16 = %d\n", __func__, hparams.f16); + fprintf(stderr, "%s: type = %d\n", __func__, model.type); + + // print memory requirements + { + // this is the total memory required to run the inference + const size_t mem_required = + MEM_REQ_SCRATCH0.at (model.type) + + MEM_REQ_SCRATCH1.at (model.type) + + MEM_REQ_SCRATCH2.at (model.type) + + MEM_REQ_SCRATCH3.at (model.type) + + scale*MEM_REQ_MODEL.at (model.type) + + scale*MEM_REQ_KV_CROSS.at(model.type) + + scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type)); + + // this is the memory required by one decoder + const size_t mem_required_decoder = + scale*MEM_REQ_KV_SELF.at(model.type); + + fprintf(stderr, "%s: mem required = %7.2f MB (+ %7.2f MB per decoder)\n", __func__, + mem_required / 1024.0 / 1024.0, mem_required_decoder / 1024.0 / 1024.0); + } + + // initialize all memory buffers + // always have at least one decoder + + wctx.model.buf = new std::vector(); + wctx.model.buf->resize(scale*MEM_REQ_MODEL.at(model.type)); + + if (!kv_cache_init(model.hparams, scale*MEM_REQ_KV_SELF.at(model.type), wctx.decoders[0].kv_self, wctx.wtype, model.hparams.n_text_ctx)) { + fprintf(stderr, "%s: kv_cache_init() failed for self-attention cache\n", __func__); + return false; + } + + { + const size_t memory_size = ggml_nbytes(wctx.decoders[0].kv_self.k) + ggml_nbytes(wctx.decoders[0].kv_self.v); + fprintf(stderr, "%s: kv self size = %7.2f MB\n", __func__, memory_size/1024.0/1024.0); + } + + if (!kv_cache_init(model.hparams, scale*MEM_REQ_KV_CROSS.at(model.type), wctx.kv_cross, wctx.wtype, model.hparams.n_audio_ctx)) { + fprintf(stderr, "%s: kv_cache_init() failed for cross-attention cache\n", __func__); + return false; + } + + { + const size_t memory_size = ggml_nbytes(wctx.kv_cross.k) + ggml_nbytes(wctx.kv_cross.v); + fprintf(stderr, "%s: kv cross size = %7.2f MB\n", __func__, memory_size/1024.0/1024.0); + } + + wctx.buf_compute.resize(scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type))); + + wctx.buf_scratch[0].resize(MEM_REQ_SCRATCH0.at(model.type)); + wctx.buf_scratch[1].resize(MEM_REQ_SCRATCH1.at(model.type)); + wctx.buf_scratch[2].resize(MEM_REQ_SCRATCH2.at(model.type)); + wctx.buf_scratch[3].resize(MEM_REQ_SCRATCH3.at(model.type)); + } + + // load mel filters + { + auto & filters = wctx.model.filters; + + read_safe(loader, filters.n_mel); + read_safe(loader, filters.n_fft); + + filters.data.resize(filters.n_mel * filters.n_fft); + loader->read(loader->context, filters.data.data(), filters.data.size() * sizeof(float)); + BYTESWAP_FILTERS(filters); + } + + // load vocab + { + int32_t n_vocab = 0; + read_safe(loader, n_vocab); + + //if (n_vocab != model.hparams.n_vocab) { + // fprintf(stderr, "%s: invalid model file '%s' (bad vocab 
size %d != %d)\n", + // __func__, fname.c_str(), n_vocab, model.hparams.n_vocab); + // return false; + //} + + std::string word; + std::vector tmp; + + tmp.reserve(128); + + for (int i = 0; i < n_vocab; i++) { + uint32_t len; + read_safe(loader, len); + + if (len > 0) { + tmp.resize(len); + loader->read(loader->context, &tmp[0], tmp.size()); // read to buffer + word.assign(&tmp[0], tmp.size()); + } else { + // seems like we have an empty-string token in multi-language models (i = 50256) + //fprintf(stderr, "%s: warning: empty-string token in vocab, i = %d\n", __func__, i); + word = ""; + } + + vocab.token_to_id[word] = i; + vocab.id_to_token[i] = word; + + //printf("%s: vocab[%d] = '%s'\n", __func__, i, word.c_str()); + } + + vocab.n_vocab = model.hparams.n_vocab; + if (vocab.is_multilingual()) { + vocab.token_eot++; + vocab.token_sot++; + vocab.token_prev++; + vocab.token_solm++; + vocab.token_not++; + vocab.token_beg++; + } + + if (n_vocab < model.hparams.n_vocab) { + fprintf(stderr, "%s: adding %d extra tokens\n", __func__, model.hparams.n_vocab - n_vocab); + for (int i = n_vocab; i < model.hparams.n_vocab; i++) { + if (i > vocab.token_beg) { + word = "[_TT_" + std::to_string(i - vocab.token_beg) + "]"; + } else if (i == vocab.token_eot) { + word = "[_EOT_]"; + } else if (i == vocab.token_sot) { + word = "[_SOT_]"; + } else if (i == vocab.token_prev) { + word = "[_PREV_]"; + } else if (i == vocab.token_not) { + word = "[_NOT_]"; + } else if (i == vocab.token_beg) { + word = "[_BEG_]"; + } else { + word = "[_extra_token_" + std::to_string(i) + "]"; + } + vocab.token_to_id[word] = i; + vocab.id_to_token[i] = word; + } + } + + wctx.logits.reserve(vocab.n_vocab*model.hparams.n_text_ctx); + + wctx.logits_id.reserve(n_vocab); + + // TAGS: WHISPER_DECODER_INIT + wctx.decoders[0].sequence.tokens.reserve(model.hparams.n_text_ctx); + + wctx.decoders[0].probs.reserve (vocab.n_vocab); + wctx.decoders[0].logits.reserve (vocab.n_vocab); + wctx.decoders[0].logprobs.reserve(vocab.n_vocab); + } + + size_t ctx_size = 0; + + const ggml_type wtype = wctx.wtype; + + { + const auto & hparams = model.hparams; + + const int n_vocab = hparams.n_vocab; + + const int n_audio_ctx = hparams.n_audio_ctx; + const int n_audio_state = hparams.n_audio_state; + const int n_audio_layer = hparams.n_audio_layer; + + const int n_text_ctx = hparams.n_text_ctx; + const int n_text_state = hparams.n_text_state; + const int n_text_layer = hparams.n_text_layer; + + const int n_mels = hparams.n_mels; + + // encoder + { + ctx_size += n_audio_ctx*n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_pe; + + ctx_size += 3*n_mels*n_audio_state*ggml_type_size(wtype); // e_conv_1_w + ctx_size += n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_conv_1_b + + ctx_size += 3*n_audio_state*n_audio_state*ggml_type_size(wtype); // e_conv_2_w + ctx_size += n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_conv_2_b + + ctx_size += n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_ln_w; + ctx_size += n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_ln_b; + } + + // decoder + { + ctx_size += n_text_ctx*n_text_state*ggml_type_size(GGML_TYPE_F32); // d_pe; + + ctx_size += n_vocab*n_text_state*ggml_type_size(wtype); // d_te; + + ctx_size += n_text_state*ggml_type_size(GGML_TYPE_F32); // d_ln_w; + ctx_size += n_text_state*ggml_type_size(GGML_TYPE_F32); // d_ln_b; + } + + // encoder layers + { + ctx_size += n_audio_layer*(n_audio_state*ggml_type_size(GGML_TYPE_F32)); // mlp_ln_w + ctx_size += n_audio_layer*(n_audio_state*ggml_type_size(GGML_TYPE_F32)); // 
mlp_ln_b + + ctx_size += n_audio_layer*(4*n_audio_state*n_audio_state*ggml_type_size(wtype)); // mlp_0_w + ctx_size += n_audio_layer*( 4*n_audio_state*ggml_type_size(GGML_TYPE_F32)); // mlp_0_b + + ctx_size += n_audio_layer*(4*n_audio_state*n_audio_state*ggml_type_size(wtype)); // mlp_1_w + ctx_size += n_audio_layer*( n_audio_state*ggml_type_size(GGML_TYPE_F32)); // mlp_1_b + + ctx_size += n_audio_layer*(n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_0_w + ctx_size += n_audio_layer*(n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_0_b + + ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_size(wtype)); // attn_q_w + ctx_size += n_audio_layer*( n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_q_b + + ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_size(wtype)); // attn_k_w + + ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_size(wtype)); // attn_v_w + ctx_size += n_audio_layer*( n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_v_b + + ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_size(wtype)); // attn_ln_1_w + ctx_size += n_audio_layer*( n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_1_b + } + + // decoder layers + { + ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // mlp_ln_w + ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // mlp_ln_b + + ctx_size += n_text_layer*(4*n_text_state*n_text_state*ggml_type_size(wtype)); // mlp_0_w + ctx_size += n_text_layer*( 4*n_text_state*ggml_type_size(GGML_TYPE_F32)); // mlp_0_b + + ctx_size += n_text_layer*(4*n_text_state*n_text_state*ggml_type_size(wtype)); // mlp_1_w + ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // mlp_1_b + + ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_0_w + ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_0_b + + ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // attn_q_w + ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_q_b + + ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // attn_k_w + + ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // attn_v_w + ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_v_b + + ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // attn_ln_1_w + ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_1_b + // + ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_ln_0_w + ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_ln_0_b + + ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // cross_attn_q_w + ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_q_b + + ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // cross_attn_k_w + + ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // cross_attn_v_w + ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_v_b + + ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // cross_attn_ln_1_w + ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_ln_1_b + } + + ctx_size += (15 + 15*n_audio_layer + 24*n_text_layer)*256; // object overhead + + fprintf(stderr, "%s: 
model ctx = %7.2f MB\n", __func__, ctx_size/(1024.0*1024.0)); + } + + // create the ggml context + { + struct ggml_init_params params; + params.mem_size = wctx.model.buf->size(); + params.mem_buffer = wctx.model.buf->data(); + + model.ctx = ggml_init(params); + if (!model.ctx) { + fprintf(stderr, "%s: ggml_init() failed\n", __func__); + return false; + } + } + + // prepare memory for the weights + { + auto & ctx = model.ctx; + + const auto & hparams = model.hparams; + + const int n_vocab = hparams.n_vocab; + + const int n_audio_ctx = hparams.n_audio_ctx; + const int n_audio_state = hparams.n_audio_state; + const int n_audio_layer = hparams.n_audio_layer; + + const int n_text_ctx = hparams.n_text_ctx; + const int n_text_state = hparams.n_text_state; + const int n_text_layer = hparams.n_text_layer; + + const int n_mels = hparams.n_mels; + + model.layers_encoder.resize(n_audio_layer); + model.layers_decoder.resize(n_text_layer); + + // encoder + { + model.e_pe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_state, n_audio_ctx); + + model.e_conv_1_w = ggml_new_tensor_3d(ctx, wtype, 3, n_mels, n_audio_state); + model.e_conv_1_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state); + + model.e_conv_2_w = ggml_new_tensor_3d(ctx, wtype, 3, n_audio_state, n_audio_state); + model.e_conv_2_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state); + + model.e_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state); + model.e_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state); + + // map by name + model.tensors["encoder.positional_embedding"] = model.e_pe; + + model.tensors["encoder.conv1.weight"] = model.e_conv_1_w; + model.tensors["encoder.conv1.bias"] = model.e_conv_1_b; + + model.tensors["encoder.conv2.weight"] = model.e_conv_2_w; + model.tensors["encoder.conv2.bias"] = model.e_conv_2_b; + + model.tensors["encoder.ln_post.weight"] = model.e_ln_w; + model.tensors["encoder.ln_post.bias"] = model.e_ln_b; + + for (int i = 0; i < n_audio_layer; ++i) { + auto & layer = model.layers_encoder[i]; + + layer.mlp_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state); + layer.mlp_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state); + + layer.mlp_0_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, 4*n_audio_state); + layer.mlp_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_audio_state); + + layer.mlp_1_w = ggml_new_tensor_2d(ctx, wtype, 4*n_audio_state, n_audio_state); + layer.mlp_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state); + + layer.attn_ln_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state); + layer.attn_ln_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state); + + layer.attn_q_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state); + layer.attn_q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state); + + layer.attn_k_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state); + + layer.attn_v_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state); + layer.attn_v_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state); + + layer.attn_ln_1_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state); + layer.attn_ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state); + + // map by name + model.tensors["encoder.blocks." + std::to_string(i) + ".mlp_ln.weight"] = layer.mlp_ln_w; + model.tensors["encoder.blocks." + std::to_string(i) + ".mlp_ln.bias"] = layer.mlp_ln_b; + + model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.0.weight"] = layer.mlp_0_w; + model.tensors["encoder.blocks." 
+ std::to_string(i) + ".mlp.0.bias"] = layer.mlp_0_b; + + model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.2.weight"] = layer.mlp_1_w; + model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.2.bias"] = layer.mlp_1_b; + + model.tensors["encoder.blocks." + std::to_string(i) + ".attn_ln.weight"] = layer.attn_ln_0_w; + model.tensors["encoder.blocks." + std::to_string(i) + ".attn_ln.bias"] = layer.attn_ln_0_b; + + model.tensors["encoder.blocks." + std::to_string(i) + ".attn.query.weight"] = layer.attn_q_w; + model.tensors["encoder.blocks." + std::to_string(i) + ".attn.query.bias"] = layer.attn_q_b; + + model.tensors["encoder.blocks." + std::to_string(i) + ".attn.key.weight"] = layer.attn_k_w; + + model.tensors["encoder.blocks." + std::to_string(i) + ".attn.value.weight"] = layer.attn_v_w; + model.tensors["encoder.blocks." + std::to_string(i) + ".attn.value.bias"] = layer.attn_v_b; + + model.tensors["encoder.blocks." + std::to_string(i) + ".attn.out.weight"] = layer.attn_ln_1_w; + model.tensors["encoder.blocks." + std::to_string(i) + ".attn.out.bias"] = layer.attn_ln_1_b; + } + } + + // decoder + { + model.d_pe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_text_state, n_text_ctx); + + model.d_te = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_vocab); + + model.d_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); + model.d_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); + + // map by name + model.tensors["decoder.positional_embedding"] = model.d_pe; + + model.tensors["decoder.token_embedding.weight"] = model.d_te; + + model.tensors["decoder.ln.weight"] = model.d_ln_w; + model.tensors["decoder.ln.bias"] = model.d_ln_b; + + for (int i = 0; i < n_text_layer; ++i) { + auto & layer = model.layers_decoder[i]; + + layer.mlp_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); + layer.mlp_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); + + layer.mlp_0_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, 4*n_text_state); + layer.mlp_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_text_state); + + layer.mlp_1_w = ggml_new_tensor_2d(ctx, wtype, 4*n_text_state, n_text_state); + layer.mlp_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); + + layer.attn_ln_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); + layer.attn_ln_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); + + layer.attn_q_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state); + layer.attn_q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); + + layer.attn_k_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state); + + layer.attn_v_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state); + layer.attn_v_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); + + layer.attn_ln_1_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state); + layer.attn_ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); + + layer.cross_attn_ln_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); + layer.cross_attn_ln_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); + + layer.cross_attn_q_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state); + layer.cross_attn_q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); + + layer.cross_attn_k_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state); + + layer.cross_attn_v_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state); + layer.cross_attn_v_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); + + layer.cross_attn_ln_1_w = 
ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state); + layer.cross_attn_ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); + + // map by name + model.tensors["decoder.blocks." + std::to_string(i) + ".mlp_ln.weight"] = layer.mlp_ln_w; + model.tensors["decoder.blocks." + std::to_string(i) + ".mlp_ln.bias"] = layer.mlp_ln_b; + + model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.0.weight"] = layer.mlp_0_w; + model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.0.bias"] = layer.mlp_0_b; + + model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.2.weight"] = layer.mlp_1_w; + model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.2.bias"] = layer.mlp_1_b; + + model.tensors["decoder.blocks." + std::to_string(i) + ".attn_ln.weight"] = layer.attn_ln_0_w; + model.tensors["decoder.blocks." + std::to_string(i) + ".attn_ln.bias"] = layer.attn_ln_0_b; + + model.tensors["decoder.blocks." + std::to_string(i) + ".attn.query.weight"] = layer.attn_q_w; + model.tensors["decoder.blocks." + std::to_string(i) + ".attn.query.bias"] = layer.attn_q_b; + + model.tensors["decoder.blocks." + std::to_string(i) + ".attn.key.weight"] = layer.attn_k_w; + + model.tensors["decoder.blocks." + std::to_string(i) + ".attn.value.weight"] = layer.attn_v_w; + model.tensors["decoder.blocks." + std::to_string(i) + ".attn.value.bias"] = layer.attn_v_b; + + model.tensors["decoder.blocks." + std::to_string(i) + ".attn.out.weight"] = layer.attn_ln_1_w; + model.tensors["decoder.blocks." + std::to_string(i) + ".attn.out.bias"] = layer.attn_ln_1_b; + + model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn_ln.weight"] = layer.cross_attn_ln_0_w; + model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn_ln.bias"] = layer.cross_attn_ln_0_b; + + model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.query.weight"] = layer.cross_attn_q_w; + model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.query.bias"] = layer.cross_attn_q_b; + + model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.key.weight"] = layer.cross_attn_k_w; + + model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.value.weight"] = layer.cross_attn_v_w; + model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.value.bias"] = layer.cross_attn_v_b; + + model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.out.weight"] = layer.cross_attn_ln_1_w; + model.tensors["decoder.blocks." 
+ std::to_string(i) + ".cross_attn.out.bias"] = layer.cross_attn_ln_1_b; + } + } + } + + // load weights + { + size_t total_size = 0; + + model.n_loaded = 0; + + while (true) { + int32_t n_dims; + int32_t length; + int32_t ftype; + + read_safe(loader, n_dims); + read_safe(loader, length); + read_safe(loader, ftype); + + if (loader->eof(loader->context)) { + break; + } + + int32_t nelements = 1; + int32_t ne[3] = { 1, 1, 1 }; + for (int i = 0; i < n_dims; ++i) { + read_safe(loader, ne[i]); + nelements *= ne[i]; + } + + std::string name; + std::vector tmp(length); // create a buffer + loader->read(loader->context, &tmp[0], tmp.size()); // read to buffer + name.assign(&tmp[0], tmp.size()); + + if (model.tensors.find(name) == model.tensors.end()) { + fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data()); + return false; + } + + auto tensor = model.tensors[name.data()]; + if (ggml_nelements(tensor) != nelements) { + fprintf(stderr, "%s: tensor '%s' has wrong size in model file\n", __func__, name.data()); + return false; + } + + if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1] || tensor->ne[2] != ne[2]) { + fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d, %d], expected [%d, %d, %d]\n", + __func__, name.data(), tensor->ne[0], tensor->ne[1], tensor->ne[2], ne[0], ne[1], ne[2]); + return false; + } + + const size_t bpe = (ftype == 0) ? sizeof(float) : sizeof(ggml_fp16_t); + + if (nelements*bpe != ggml_nbytes(tensor)) { + fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n", + __func__, name.data(), ggml_nbytes(tensor), nelements*bpe); + return false; + } + + loader->read(loader->context, tensor->data, ggml_nbytes(tensor)); + BYTESWAP_TENSOR(tensor); + + //printf("%48s - [%5d, %5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ne[2], ftype == 0 ? "float" : "f16", ggml_nbytes(tensor)/1024.0/1024.0); + total_size += ggml_nbytes(tensor); + model.n_loaded++; + } + + fprintf(stderr, "%s: model size = %7.2f MB\n", __func__, total_size/1024.0/1024.0); + + if (model.n_loaded == 0) { + fprintf(stderr, "%s: WARN no tensors loaded from model file - assuming empty model for testing\n", __func__); + } else if (model.n_loaded != (int) model.tensors.size()) { + fprintf(stderr, "%s: ERROR not all tensors loaded from model file - expected %zu, got %d\n", __func__, model.tensors.size(), model.n_loaded); + return false; + } + } + + wctx.rng = std::mt19937(0); + + wctx.t_load_us = ggml_time_us() - t_start_us; + + return true; +} + +// evaluate the encoder +// +// given audio recording (more specifically, its log mel spectrogram), runs forward pass of the encoder +// part of the transformer model and returns the encoded features +// +// - model: the model +// - n_threads: number of threads to use +// - mel_offset: offset in the mel spectrogram (i.e. audio offset) +// +static bool whisper_encode( + whisper_context & wctx, + const int mel_offset, + const int n_threads) { + const int64_t t_start_us = ggml_time_us(); + + const auto & model = wctx.model; + const auto & mel_inp = wctx.mel; + const auto & hparams = model.hparams; + + const int n_ctx = wctx.exp_n_audio_ctx > 0 ? 
wctx.exp_n_audio_ctx : hparams.n_audio_ctx; + const int n_state = hparams.n_audio_state; + const int n_head = hparams.n_audio_head; + const int n_layer = hparams.n_audio_layer; + + const int n_mels = hparams.n_mels; + assert(mel_inp.n_mel == n_mels); + + struct ggml_init_params params; + params.mem_size = wctx.buf_compute.size(); + params.mem_buffer = wctx.buf_compute.data(); + + struct ggml_context * ctx0 = ggml_init(params); + + wctx.use_buf(ctx0, 0); + + struct ggml_tensor * mel = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2*n_ctx, n_mels); + assert(mel->type == GGML_TYPE_F32); + { + float * dst = (float *) mel->data; + memset(dst, 0, ggml_nbytes(mel)); + + const int i0 = std::min(mel_offset, mel_inp.n_len); + const int i1 = std::min(mel_offset + 2*n_ctx, mel_inp.n_len); + + for (int j = 0; j < mel_inp.n_mel; ++j) { + for (int i = i0; i < i1; ++i) { + dst[j*2*n_ctx + (i - i0)] = mel_inp.data[j*mel_inp.n_len + i]; + } + } + } + + struct ggml_tensor * cur; + + // convolution + gelu + { + wctx.use_buf(ctx0, 1); + + cur = ggml_conv_1d_1s(ctx0, model.e_conv_1_w, mel); + cur = ggml_add(ctx0, + ggml_repeat(ctx0, + model.e_conv_1_b, + cur), + cur); + + cur = ggml_gelu(ctx0, cur); + + wctx.use_buf(ctx0, 0); + + cur = ggml_conv_1d_2s(ctx0, model.e_conv_2_w, cur); + cur = ggml_add(ctx0, + ggml_repeat(ctx0, + model.e_conv_2_b, + cur), + cur); + + cur = ggml_gelu(ctx0, cur); + } + + wctx.use_buf(ctx0, 3); + + // =================================================================== + // NOTE: experimenting with partial evaluation of the encoder (ignore) + //static int iter = -1; + //const int n_iter = 1500/n_ctx; + + //iter = (iter + 1) % n_iter; + + //if (iter == 0) { + // memset(model.memory_cross_k->data, 0, ggml_nbytes(model.memory_cross_k)); + // memset(model.memory_cross_v->data, 0, ggml_nbytes(model.memory_cross_v)); + //} + + static int iter = 0; + + const size_t e_pe_stride = model.e_pe->ne[0]*ggml_element_size(model.e_pe); + const size_t e_pe_offset = model.e_pe->ne[0]*ggml_element_size(model.e_pe)*n_ctx*iter; + + struct ggml_tensor * e_pe = ggml_view_2d(ctx0, model.e_pe, model.e_pe->ne[0], n_ctx, e_pe_stride, e_pe_offset); + + cur = ggml_add(ctx0, e_pe, ggml_transpose(ctx0, cur)); + + // =================================================================== + + // original: + //cur = ggml_add(ctx0, model.e_pe, ggml_transpose(ctx0, cur)); + + struct ggml_tensor * inpL = cur; + + for (int il = 0; il < n_layer; ++il) { + const auto & layer = model.layers_encoder[il]; + + // norm + { + wctx.use_buf(ctx0, 0); + + cur = ggml_norm(ctx0, inpL); + + // cur = ln_0_w*cur + ln_0_b + cur = ggml_add(ctx0, + ggml_mul(ctx0, + ggml_repeat(ctx0, layer.attn_ln_0_w, cur), + cur), + ggml_repeat(ctx0, layer.attn_ln_0_b, cur)); + } + + // self-attention + { + wctx.use_buf(ctx0, 1); + + struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, + layer.attn_q_w, + cur); + + Qcur = ggml_add(ctx0, + ggml_repeat(ctx0, + layer.attn_q_b, + Qcur), + Qcur); + + //Qcur = ggml_scale(ctx0, Qcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25))); + + // note: no bias for Key + struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, + layer.attn_k_w, + cur); + + //Kcur = ggml_scale(ctx0, Kcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25))); + + struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, + layer.attn_v_w, + cur); + + Vcur = ggml_add(ctx0, + ggml_repeat(ctx0, + layer.attn_v_b, + Vcur), + Vcur); + + // ------ + + wctx.use_buf(ctx0, 0); + +#ifdef WHISPER_USE_FLASH_ATTN + struct ggml_tensor * Q = + ggml_permute(ctx0, + ggml_cpy(ctx0, + Qcur, 
+ ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_head, n_ctx)), + 0, 2, 1, 3); + + struct ggml_tensor * K = + ggml_permute(ctx0, + ggml_cpy(ctx0, + Kcur, + ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_head, n_ctx)), + 0, 2, 1, 3); + + struct ggml_tensor * V = + ggml_cpy(ctx0, + ggml_permute(ctx0, + ggml_reshape_3d(ctx0, + Vcur, + n_state/n_head, n_head, n_ctx), + 1, 2, 0, 3), + ggml_new_tensor_3d(ctx0, wctx.wtype, n_ctx, n_state/n_head, n_head) + ); + + struct ggml_tensor * KQV = ggml_flash_attn(ctx0, Q, K, V, false); +#else + struct ggml_tensor * Q = + ggml_permute(ctx0, + ggml_cpy(ctx0, + Qcur, + ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state/n_head, n_head, n_ctx)), + 0, 2, 1, 3); + + struct ggml_tensor * K = + ggml_permute(ctx0, + ggml_cpy(ctx0, + Kcur, + ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_head, n_ctx)), + 0, 2, 1, 3); + + // K * Q + struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); + + struct ggml_tensor * KQ_scaled = + ggml_scale(ctx0, + KQ, + ggml_new_f32(ctx0, 1.0f/sqrt(float(n_state)/n_head)) + ); + + struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_scaled); + + //struct ggml_tensor * V_trans = + // ggml_permute(ctx0, + // ggml_cpy(ctx0, + // Vcur, + // ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_head, n_ctx)), + // 1, 2, 0, 3); + + //struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max); + + struct ggml_tensor * V = + ggml_cpy(ctx0, + ggml_permute(ctx0, + ggml_reshape_3d(ctx0, + Vcur, + n_state/n_head, n_head, n_ctx), + 0, 2, 1, 3), + ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_ctx, n_head) + ); + + struct ggml_tensor * KQV = ggml_mul_mat(ctx0, ggml_transpose(ctx0, V), KQ_soft_max); +#endif + struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); + + wctx.use_buf(ctx0, 1); + + cur = ggml_cpy(ctx0, + KQV_merged, + ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx)); + } + + // projection + { + wctx.use_buf(ctx0, 0); + + cur = ggml_mul_mat(ctx0, + layer.attn_ln_1_w, + cur); + + wctx.use_buf(ctx0, 1); + + cur = ggml_add(ctx0, + ggml_repeat(ctx0, layer.attn_ln_1_b, cur), + cur); + } + + wctx.use_buf(ctx0, 2); + + // add the input + cur = ggml_add(ctx0, cur, inpL); + + struct ggml_tensor * inpFF = cur; + + // feed-forward network + { + // norm + { + wctx.use_buf(ctx0, 0); + + cur = ggml_norm(ctx0, inpFF); + + wctx.use_buf(ctx0, 1); + + // cur = mlp_ln_w*cur + mlp_ln_b + cur = ggml_add(ctx0, + ggml_mul(ctx0, + ggml_repeat(ctx0, layer.mlp_ln_w, cur), + cur), + ggml_repeat(ctx0, layer.mlp_ln_b, cur)); + } + +#ifdef WHISPER_USE_FLASH_FF + wctx.use_buf(ctx0, 0); + + cur = ggml_flash_ff(ctx0, + ggml_cpy(ctx0, cur, ggml_new_tensor_2d(ctx0, wctx.wtype, n_state, n_ctx)), + layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b); +#else + wctx.use_buf(ctx0, 0); + + // fully connected + cur = ggml_mul_mat(ctx0, + layer.mlp_0_w, + cur); + + wctx.use_buf(ctx0, 1); + + cur = ggml_add(ctx0, + ggml_repeat(ctx0, layer.mlp_0_b, cur), + cur); + + wctx.use_buf(ctx0, 0); + + // GELU activation + cur = ggml_gelu(ctx0, cur); + + wctx.use_buf(ctx0, 1); + + // projection + cur = ggml_mul_mat(ctx0, + layer.mlp_1_w, + cur); + + wctx.use_buf(ctx0, 0); + + cur = ggml_add(ctx0, + ggml_repeat(ctx0, layer.mlp_1_b, cur), + cur); +#endif + } + + wctx.use_buf(ctx0, 3); + + inpL = ggml_add(ctx0, cur, inpFF); + } + + cur = inpL; + + // norm + { + wctx.use_buf(ctx0, 0); + + cur = ggml_norm(ctx0, cur); + + wctx.use_buf(ctx0, 1); + + // cur = ln_f_g*cur + ln_f_b + cur = ggml_add(ctx0, + ggml_mul(ctx0, + 
ggml_repeat(ctx0, model.e_ln_w, cur), + cur), + ggml_repeat(ctx0, model.e_ln_b, cur)); + } + + wctx.use_buf(ctx0, -1); + + // run the computation + { + struct ggml_cgraph gf = {}; + gf.n_threads = n_threads; + + ggml_build_forward_expand(&gf, cur); + ggml_graph_compute (ctx0, &gf); + + //ggml_graph_print(&gf); + } + + // cur + //{ + // printf("ne0 = %d\n", cur->ne[0]); + // printf("ne1 = %d\n", cur->ne[1]); + // for (int i = 0; i < 10; ++i) { + // printf("%8.4f ", ((float *)(cur->data))[i]); + // } + // printf("... "); + // for (int i = cur->ne[0] - 10; i < cur->ne[0]; ++i) { + // printf("%8.4f ", ((float *)(cur->data))[i]); + // } + // printf("\n"); + //} + + // pre-compute cross-attention memory + { + struct ggml_cgraph gf = {}; + gf.n_threads = n_threads; + + // TODO: hack to disconnect the encoded features from the previous graph + cur->op = GGML_OP_NONE; + cur->src0 = nullptr; + cur->src1 = nullptr; + + for (int il = 0; il < model.hparams.n_text_layer; ++il) { + auto & layer = model.layers_decoder[il]; + + wctx.use_buf(ctx0, 0); + + struct ggml_tensor * Kcross = ggml_mul_mat(ctx0, + layer.cross_attn_k_w, + cur); + + Kcross = ggml_scale(ctx0, Kcross, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25))); + + wctx.use_buf(ctx0, 1); + + struct ggml_tensor * Vcross = ggml_mul_mat(ctx0, + layer.cross_attn_v_w, + cur); + + Vcross = ggml_add(ctx0, + ggml_repeat(ctx0, + layer.cross_attn_v_b, + Vcross), + Vcross); + + wctx.use_buf(ctx0, -1); + + //struct ggml_tensor * k = ggml_view_1d(ctx0, wctx.kv_cross.k, n_state*n_ctx, (ggml_element_size(wctx.kv_cross.k)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx)); + //struct ggml_tensor * v = ggml_view_1d(ctx0, wctx.kv_cross.v, n_state*n_ctx, (ggml_element_size(wctx.kv_cross.v)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx)); + struct ggml_tensor * k = ggml_view_1d(ctx0, wctx.kv_cross.k, n_state*n_ctx, (ggml_element_size(wctx.kv_cross.k)*n_state)*(il*n_ctx)); + struct ggml_tensor * v = ggml_view_1d(ctx0, wctx.kv_cross.v, n_state*n_ctx, (ggml_element_size(wctx.kv_cross.v)*n_state)*(il*n_ctx)); + + ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcross, k)); + ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcross, v)); + } + + ggml_graph_compute(ctx0, &gf); + //ggml_graph_print(&gf); + } + + //////////////////////////////////////////////////////////////////////////// + + //printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__, + // ggml_used_mem(ctx0)/1024.0/1024.0, + // wctx.get_buf_max_mem(0)/1024.0/1024.0, + // wctx.get_buf_max_mem(1)/1024.0/1024.0, + // wctx.get_buf_max_mem(2)/1024.0/1024.0, + // wctx.get_buf_max_mem(3)/1024.0/1024.0); + + ggml_free(ctx0); + + wctx.t_encode_us += ggml_time_us() - t_start_us; + wctx.n_encode++; + + return true; +} + +// evaluate the decoder +// +// given text prompt + audio features -> computes the logits for the next token +// +// - model: the model +// - n_threads: number of threads to use +// - tokens: text prompt +// - n_tokens: number of tokens in the prompt +// - n_past: number of past tokens to prefix the prompt with +// +static bool whisper_decode( + whisper_context & wctx, + whisper_decoder & decoder, + const whisper_token * tokens, + const int n_tokens, + const int n_past, + const int n_threads) { + const int64_t t_start_us = ggml_time_us(); + + const auto & model = wctx.model; + const auto & hparams = model.hparams; + + auto & kv_self = decoder.kv_self; + + WHISPER_ASSERT(!!kv_self.ctx); + + auto & logits_out = wctx.logits; + + const int n_vocab = hparams.n_vocab; + + const int n_ctx = 
hparams.n_text_ctx; + const int n_state = hparams.n_text_state; + const int n_head = hparams.n_text_head; + const int n_layer = hparams.n_text_layer; + + const int N = n_tokens; + const int M = wctx.exp_n_audio_ctx > 0 ? wctx.exp_n_audio_ctx : hparams.n_audio_ctx; + + //WHISPER_PRINT_DEBUG("%s: n_past = %d, N = %d, M = %d, n_ctx = %d\n", __func__, n_past, N, M, n_ctx); + + struct ggml_init_params params; + params.mem_size = wctx.buf_compute.size(); + params.mem_buffer = wctx.buf_compute.data(); + + struct ggml_context * ctx0 = ggml_init(params); + + struct ggml_cgraph gf = {}; + gf.n_threads = n_threads; + + struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); + memcpy(embd->data, tokens, N*ggml_element_size(embd)); + + struct ggml_tensor * position = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); + for (int i = 0; i < N; ++i) { + ((int32_t *) position->data)[i] = n_past + i; + } + + wctx.use_buf(ctx0, 3); + + // token encoding + position encoding + struct ggml_tensor * cur = + ggml_add(ctx0, + ggml_get_rows(ctx0, model.d_te, embd), + ggml_get_rows(ctx0, model.d_pe, position)); + + struct ggml_tensor * inpL = cur; + + for (int il = 0; il < n_layer; ++il) { + const auto & layer = model.layers_decoder[il]; + + // norm + { + wctx.use_buf(ctx0, 0); + + cur = ggml_norm(ctx0, inpL); + + // cur = ln_0_w*cur + ln_0_b + cur = ggml_add(ctx0, + ggml_mul(ctx0, + ggml_repeat(ctx0, layer.attn_ln_0_w, cur), + cur), + ggml_repeat(ctx0, layer.attn_ln_0_b, cur)); + } + + // self-attention + { + wctx.use_buf(ctx0, 1); + + struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, + layer.attn_q_w, + cur); + + Qcur = ggml_add(ctx0, + ggml_repeat(ctx0, + layer.attn_q_b, + Qcur), + Qcur); + + Qcur = ggml_scale(ctx0, Qcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25))); + + // note: no bias for Key + struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, + layer.attn_k_w, + cur); + + Kcur = ggml_scale(ctx0, Kcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25))); + + struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, + layer.attn_v_w, + cur); + + Vcur = ggml_add(ctx0, + ggml_repeat(ctx0, + layer.attn_v_b, + Vcur), + Vcur); + + // store key and value to memory + { + struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_state, (ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + n_past)); + struct ggml_tensor * v = ggml_view_1d(ctx0, kv_self.v, N*n_state, (ggml_element_size(kv_self.v)*n_state)*(il*n_ctx + n_past)); + + ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k)); + ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v)); + } + + // ------ + + wctx.use_buf(ctx0, 0); + + struct ggml_tensor * Q = + ggml_permute(ctx0, + ggml_cpy(ctx0, + Qcur, + ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state/n_head, n_head, N)), + 0, 2, 1, 3); + + struct ggml_tensor * K = + ggml_permute(ctx0, + ggml_reshape_3d(ctx0, + ggml_view_1d(ctx0, kv_self.k, (n_past + N)*n_state, il*n_ctx*ggml_element_size(kv_self.k)*n_state), + n_state/n_head, n_head, n_past + N), + 0, 2, 1, 3); + + wctx.use_buf(ctx0, 1); + + // K * Q + struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); + + wctx.use_buf(ctx0, 0); + + //struct ggml_tensor * KQ_scaled = + // ggml_scale(ctx0, + // KQ, + // ggml_new_f32(ctx0, 1.0f/sqrt(float(n_state)/n_head)) + // ); + + struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ, n_past); + + wctx.use_buf(ctx0, 1); + + struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked); + + wctx.use_buf(ctx0, 0); + + struct ggml_tensor * V_trans = + ggml_permute(ctx0, + ggml_reshape_3d(ctx0, + 
ggml_view_1d(ctx0, kv_self.v, (n_past + N)*n_state, il*n_ctx*ggml_element_size(kv_self.v)*n_state), + n_state/n_head, n_head, n_past + N), + 1, 2, 0, 3); + + wctx.use_buf(ctx0, 1); + + struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max); + + struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); + + cur = ggml_cpy(ctx0, + KQV_merged, + ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, N)); + } + + // projection + { + wctx.use_buf(ctx0, 0); + + cur = ggml_mul_mat(ctx0, + layer.attn_ln_1_w, + cur); + + wctx.use_buf(ctx0, 1); + + cur = ggml_add(ctx0, + ggml_repeat(ctx0, layer.attn_ln_1_b, cur), + cur); + } + + wctx.use_buf(ctx0, 2); + + // add the input + struct ggml_tensor * inpCA = ggml_add(ctx0, cur, inpL); + + // norm + { + wctx.use_buf(ctx0, 0); + + cur = ggml_norm(ctx0, inpCA); // note: we use inpCA here + + wctx.use_buf(ctx0, 1); + + // cur = ln_0_w*cur + ln_0_b + cur = ggml_add(ctx0, + ggml_mul(ctx0, + ggml_repeat(ctx0, layer.cross_attn_ln_0_w, cur), + cur), + ggml_repeat(ctx0, layer.cross_attn_ln_0_b, cur)); + } + + // cross-attention + { + wctx.use_buf(ctx0, 0); + + struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, + layer.cross_attn_q_w, + cur); + + Qcur = ggml_add(ctx0, + ggml_repeat(ctx0, + layer.cross_attn_q_b, + Qcur), + Qcur); + + Qcur = ggml_scale(ctx0, Qcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25))); + + // Kcross is already scaled + struct ggml_tensor * Kcross = + ggml_reshape_3d(ctx0, + ggml_view_1d(ctx0, wctx.kv_cross.k, M*n_state, il*M*ggml_element_size(wctx.kv_cross.k)*n_state), + n_state/n_head, n_head, M); + + struct ggml_tensor * Vcross = + ggml_reshape_3d(ctx0, + ggml_view_1d(ctx0, wctx.kv_cross.v, M*n_state, il*M*ggml_element_size(wctx.kv_cross.v)*n_state), + n_state/n_head, n_head, M); + + struct ggml_tensor * V_trans = ggml_permute(ctx0, Vcross, 1, 2, 0, 3); + + // ------ + + wctx.use_buf(ctx0, 1); + + struct ggml_tensor * Q = + ggml_permute(ctx0, + ggml_cpy(ctx0, + Qcur, + ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state/n_head, n_head, N)), + 0, 2, 1, 3); + + struct ggml_tensor * K = ggml_permute(ctx0, Kcross, 0, 2, 1, 3); + + wctx.use_buf(ctx0, 0); + + // K * Q + struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); + + //struct ggml_tensor * KQ_scaled = + // ggml_scale(ctx0, + // KQ, + // ggml_new_f32(ctx0, 1.0f/sqrt(float(n_state)/n_head)) + // ); + + // no masking for cross-attention + //struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past); + + wctx.use_buf(ctx0, 1); + + struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ); + + wctx.use_buf(ctx0, 0); + + struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max); + + wctx.use_buf(ctx0, 1); + + struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); + + // cur = KQV_merged.contiguous().view(n_state, N) + cur = ggml_cpy(ctx0, + KQV_merged, + ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, N)); + } + + // projection + { + wctx.use_buf(ctx0, 0); + + cur = ggml_mul_mat(ctx0, + layer.cross_attn_ln_1_w, + cur); + + wctx.use_buf(ctx0, 1); + + cur = ggml_add(ctx0, + ggml_repeat(ctx0, layer.cross_attn_ln_1_b, cur), + cur); + } + + wctx.use_buf(ctx0, 2); + + // add the input + cur = ggml_add(ctx0, cur, inpCA); + + struct ggml_tensor * inpFF = cur; + + // feed-forward network + { + // norm + { + wctx.use_buf(ctx0, 0); + + cur = ggml_norm(ctx0, inpFF); + + wctx.use_buf(ctx0, 1); + + // cur = mlp_ln_w*cur + mlp_ln_b + cur = ggml_add(ctx0, + ggml_mul(ctx0, + ggml_repeat(ctx0, layer.mlp_ln_w, cur), + cur), + ggml_repeat(ctx0, 
layer.mlp_ln_b, cur)); + } + + wctx.use_buf(ctx0, 0); + + // fully connected + cur = ggml_mul_mat(ctx0, + layer.mlp_0_w, + cur); + + wctx.use_buf(ctx0, 1); + + cur = ggml_add(ctx0, + ggml_repeat(ctx0, layer.mlp_0_b, cur), + cur); + + wctx.use_buf(ctx0, 0); + + // GELU activation + cur = ggml_gelu(ctx0, cur); + + wctx.use_buf(ctx0, 1); + + // projection + cur = ggml_mul_mat(ctx0, + layer.mlp_1_w, + cur); + + wctx.use_buf(ctx0, 0); + + cur = ggml_add(ctx0, + ggml_repeat(ctx0, layer.mlp_1_b, cur), + cur); + } + + wctx.use_buf(ctx0, 3); + + inpL = ggml_add(ctx0, cur, inpFF); + } + + cur = inpL; + + // norm + { + wctx.use_buf(ctx0, 0); + + cur = ggml_norm(ctx0, cur); + + wctx.use_buf(ctx0, 1); + + cur = ggml_add(ctx0, + ggml_mul(ctx0, + ggml_repeat(ctx0, model.d_ln_w, cur), + cur), + ggml_repeat(ctx0, model.d_ln_b, cur)); + } + + wctx.use_buf(ctx0, 0); + + // compute logits only for the last token + // comment this line to compute logits for all N tokens + // might be useful in the future + cur = ggml_view_2d(ctx0, cur, cur->ne[0], 1, cur->nb[1], (cur->ne[1] - 1)*cur->nb[1]); + + struct ggml_tensor * logits = ggml_mul_mat(ctx0, model.d_te, cur); + + wctx.use_buf(ctx0, -1); + + // run the computation + { + ggml_build_forward_expand(&gf, logits); + ggml_graph_compute (ctx0, &gf); + } + + // extract logits for all N tokens + //logits_out.resize(N*n_vocab); + //memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*N*n_vocab); + + // extract logits only for the last token + logits_out.resize(n_vocab); + memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*n_vocab); + + if (N > 1) { + //printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__, + // ggml_used_mem(ctx0)/1024.0/1024.0, + // wctx.get_buf_max_mem(0)/1024.0/1024.0, + // wctx.get_buf_max_mem(1)/1024.0/1024.0, + // wctx.get_buf_max_mem(2)/1024.0/1024.0, + // wctx.get_buf_max_mem(3)/1024.0/1024.0); + } + + ggml_free(ctx0); + + wctx.t_decode_us += ggml_time_us() - t_start_us; + wctx.n_decode++; + + return true; +} + +// 500 -> 00:05.000 +// 6000 -> 01:00.000 +static std::string to_timestamp(int64_t t, bool comma = false) { + int64_t msec = t * 10; + int64_t hr = msec / (1000 * 60 * 60); + msec = msec - hr * (1000 * 60 * 60); + int64_t min = msec / (1000 * 60); + msec = msec - min * (1000 * 60); + int64_t sec = msec / 1000; + msec = msec - sec * 1000; + + char buf[32]; + snprintf(buf, sizeof(buf), "%02d:%02d:%02d%s%03d", (int) hr, (int) min, (int) sec, comma ? 
"," : ".", (int) msec); + + return std::string(buf); +} + +// naive Discrete Fourier Transform +// input is real-valued +// output is complex-valued +static void dft(const std::vector & in, std::vector & out) { + int N = in.size(); + + out.resize(N*2); + + for (int k = 0; k < N; k++) { + float re = 0; + float im = 0; + + for (int n = 0; n < N; n++) { + float angle = 2*M_PI*k*n/N; + re += in[n]*cos(angle); + im -= in[n]*sin(angle); + } + + out[k*2 + 0] = re; + out[k*2 + 1] = im; + } +} + +// Cooley-Tukey FFT +// poor man's implementation - use something better +// input is real-valued +// output is complex-valued +static void fft(const std::vector & in, std::vector & out) { + out.resize(in.size()*2); + + int N = in.size(); + + if (N == 1) { + out[0] = in[0]; + out[1] = 0; + return; + } + + if (N%2 == 1) { + dft(in, out); + return; + } + + std::vector even; + std::vector odd; + + even.reserve(N/2); + odd.reserve(N/2); + + for (int i = 0; i < N; i++) { + if (i % 2 == 0) { + even.push_back(in[i]); + } else { + odd.push_back(in[i]); + } + } + + std::vector even_fft; + std::vector odd_fft; + + fft(even, even_fft); + fft(odd, odd_fft); + + for (int k = 0; k < N/2; k++) { + float theta = 2*M_PI*k/N; + + float re = cos(theta); + float im = -sin(theta); + + float re_odd = odd_fft[2*k + 0]; + float im_odd = odd_fft[2*k + 1]; + + out[2*k + 0] = even_fft[2*k + 0] + re*re_odd - im*im_odd; + out[2*k + 1] = even_fft[2*k + 1] + re*im_odd + im*re_odd; + + out[2*(k + N/2) + 0] = even_fft[2*k + 0] - re*re_odd + im*im_odd; + out[2*(k + N/2) + 1] = even_fft[2*k + 1] - re*im_odd - im*re_odd; + } +} + +// ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L92-L124 +static bool log_mel_spectrogram( + whisper_context & wctx, + const float * samples, + const int n_samples, + const int /*sample_rate*/, + const int fft_size, + const int fft_step, + const int n_mel, + const int n_threads, + const whisper_filters & filters, + const bool speed_up, + whisper_mel & mel) { + const int64_t t_start_us = ggml_time_us(); + + // Hanning window + std::vector hann; + hann.resize(fft_size); + for (int i = 0; i < fft_size; i++) { + hann[i] = 0.5*(1.0 - cos((2.0*M_PI*i)/(fft_size))); + } + + mel.n_mel = n_mel; + mel.n_len = (n_samples)/fft_step; + mel.data.resize(mel.n_mel*mel.n_len); + + const int n_fft = 1 + (speed_up ? 
fft_size/4 : fft_size/2);
+
+    //printf("%s: n_samples = %d, n_len = %d\n", __func__, n_samples, mel.n_len);
+    //printf("%s: recording length: %f s\n", __func__, (float) n_samples/sample_rate);
+
+    std::vector<std::thread> workers(n_threads);
+    for (int iw = 0; iw < n_threads; ++iw) {
+        workers[iw] = std::thread([&](int ith) {
+            std::vector<float> fft_in;
+            fft_in.resize(fft_size);
+            for (int i = 0; i < fft_size; i++) {
+                fft_in[i] = 0.0;
+            }
+
+            std::vector<float> fft_out;
+            fft_out.resize(2*fft_size);
+
+            for (int i = ith; i < mel.n_len; i += n_threads) {
+                const int offset = i*fft_step;
+
+                // apply Hanning window
+                for (int j = 0; j < fft_size; j++) {
+                    if (offset + j < n_samples) {
+                        fft_in[j] = hann[j]*samples[offset + j];
+                    } else {
+                        fft_in[j] = 0.0;
+                    }
+                }
+
+                // FFT -> mag^2
+                fft(fft_in, fft_out);
+
+                for (int j = 0; j < fft_size; j++) {
+                    fft_out[j] = (fft_out[2*j + 0]*fft_out[2*j + 0] + fft_out[2*j + 1]*fft_out[2*j + 1]);
+                }
+                for (int j = 1; j < fft_size/2; j++) {
+                    //if (i == 0) {
+                    //    printf("%d: %f %f\n", j, fft_out[j], fft_out[fft_size - j]);
+                    //}
+                    fft_out[j] += fft_out[fft_size - j];
+                }
+                if (i == 0) {
+                    //for (int j = 0; j < fft_size; j++) {
+                    //    printf("%d: %e\n", j, fft_out[j]);
+                    //}
+                }
+
+                if (speed_up) {
+                    // scaling down in the frequency domain results in a speed-up in the time domain
+                    for (int j = 0; j < n_fft; j++) {
+                        fft_out[j] = 0.5*(fft_out[2*j] + fft_out[2*j + 1]);
+                    }
+                }
+
+                // mel spectrogram
+                for (int j = 0; j < mel.n_mel; j++) {
+                    double sum = 0.0;
+
+                    for (int k = 0; k < n_fft; k++) {
+                        sum += fft_out[k]*filters.data[j*n_fft + k];
+                    }
+                    if (sum < 1e-10) {
+                        sum = 1e-10;
+                    }
+
+                    sum = log10(sum);
+
+                    mel.data[j*mel.n_len + i] = sum;
+                }
+            }
+        }, iw);
+    }
+
+    for (int iw = 0; iw < n_threads; ++iw) {
+        workers[iw].join();
+    }
+
+    // clamping and normalization
+    double mmax = -1e20;
+    for (int i = 0; i < mel.n_mel*mel.n_len; i++) {
+        if (mel.data[i] > mmax) {
+            mmax = mel.data[i];
+        }
+    }
+    //printf("%s: max = %f\n", __func__, mmax);
+
+    mmax -= 8.0;
+
+    for (int i = 0; i < mel.n_mel*mel.n_len; i++) {
+        if (mel.data[i] < mmax) {
+            mel.data[i] = mmax;
+        }
+
+        mel.data[i] = (mel.data[i] + 4.0)/4.0;
+    }
+
+    wctx.t_mel_us += ggml_time_us() - t_start_us;
+
+    return true;
+}
+
+// split text into tokens
+//
+// ref: https://github.com/openai/gpt-2/blob/a74da5d99abaaba920de8131d64da2862a8f213b/src/encoder.py#L53
+//
+// Regex (Python):
+// r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
+//
+// Regex (C++):
+// R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)"
+//
+static std::vector<whisper_token> tokenize(const whisper_vocab & vocab, const std::string & text) {
+    std::vector<std::string> words;
+
+    // first split the text into words
+    {
+        std::string str = text;
+        std::string pat = R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)";
+
+        std::regex re(pat);
+        std::smatch m;
+
+        while (std::regex_search(str, m, re)) {
+            for (auto x : m) {
+                words.push_back(x);
+            }
+            str = m.suffix();
+        }
+    }
+
+    // find the longest tokens that form the words:
+    std::vector<whisper_token> tokens;
+    for (const auto & word : words) {
+        if (word.empty()) continue;
+
+        int i = 0;
+        int n = word.size();
+        while (i < n) {
+            int j = n;
+            while (j > i) {
+                auto it = vocab.token_to_id.find(word.substr(i, j-i));
+                if (it != vocab.token_to_id.end()) {
+                    tokens.push_back(it->second);
+                    i = j;
+                    break;
+                }
+                --j;
+            }
+            if (i == n) {
+                break;
+            }
+            if (j == i) {
+                auto sub = word.substr(i, 1);
+                if 
(vocab.token_to_id.find(sub) != vocab.token_to_id.end()) { + tokens.push_back(vocab.token_to_id.at(sub)); + } else { + fprintf(stderr, "%s: unknown token '%s'\n", __func__, sub.data()); + } + ++i; + } + } + } + + return tokens; +} + +// +// interface implementation +// + +struct whisper_context * whisper_init_from_file(const char * path_model) { + whisper_model_loader loader = {}; + + fprintf(stderr, "%s: loading model from '%s'\n", __func__, path_model); + + auto fin = std::ifstream(path_model, std::ios::binary); + if (!fin) { + fprintf(stderr, "%s: failed to open '%s'\n", __func__, path_model); + return nullptr; + } + + loader.context = &fin; + loader.read = [](void * ctx, void * output, size_t read_size) { + std::ifstream * fin = (std::ifstream*)ctx; + fin->read((char *)output, read_size); + return read_size; + }; + + loader.eof = [](void * ctx) { + std::ifstream * fin = (std::ifstream*)ctx; + return fin->eof(); + }; + + loader.close = [](void * ctx) { + std::ifstream * fin = (std::ifstream*)ctx; + fin->close(); + }; + + return whisper_init(&loader); +} + +struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_size) { + struct buf_context { + uint8_t* buffer; + size_t size; + size_t current_offset; + }; + + buf_context ctx = { reinterpret_cast(buffer), buffer_size, 0 }; + whisper_model_loader loader = {}; + + fprintf(stderr, "%s: loading model from buffer\n", __func__); + + loader.context = &ctx; + + loader.read = [](void * ctx, void * output, size_t read_size) { + buf_context * buf = reinterpret_cast(ctx); + + size_t size_to_copy = buf->current_offset + read_size < buf->size ? read_size : buf->size - buf->current_offset; + + memcpy(output, buf->buffer + buf->current_offset, size_to_copy); + buf->current_offset += size_to_copy; + + return size_to_copy; + }; + + loader.eof = [](void * ctx) { + buf_context * buf = reinterpret_cast(ctx); + + return buf->current_offset >= buf->size; + }; + + loader.close = [](void * /*ctx*/) { }; + + return whisper_init(&loader); +} + +struct whisper_context * whisper_init(struct whisper_model_loader * loader) { + ggml_time_init(); + + whisper_context * ctx = new whisper_context; + + if (!whisper_model_load(loader, *ctx)) { + loader->close(loader->context); + fprintf(stderr, "%s: failed to load model\n", __func__); + delete ctx; + return nullptr; + } + + loader->close(loader->context); + + return ctx; +} + +void whisper_free(struct whisper_context * ctx) { + if (ctx) { + if (ctx->model.ctx) { + ggml_free(ctx->model.ctx); + } + if (ctx->model.buf) { + delete ctx->model.buf; + } + if (ctx->kv_cross.ctx) { + ggml_free(ctx->kv_cross.ctx); + } + for (int i = 0; i < WHISPER_MAX_DECODERS; ++i) { + if (ctx->decoders[i].kv_self.ctx) { + ggml_free(ctx->decoders[i].kv_self.ctx); + } + } + delete ctx; + } +} + +int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) { + if (!log_mel_spectrogram(*ctx, samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, false, ctx->mel)) { + fprintf(stderr, "%s: failed to compute mel spectrogram\n", __func__); + return -1; + } + + return 0; +} + +// same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2 +int whisper_pcm_to_mel_phase_vocoder(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) { + if (!log_mel_spectrogram(*ctx, samples, n_samples, WHISPER_SAMPLE_RATE, 2*WHISPER_N_FFT, 2*WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, 
ctx->model.filters, true, ctx->mel)) { + fprintf(stderr, "%s: failed to compute mel spectrogram\n", __func__); + return -1; + } + + return 0; +} + +int whisper_set_mel( + struct whisper_context * ctx, + const float * data, + int n_len, + int n_mel) { + if (n_mel != WHISPER_N_MEL) { + fprintf(stderr, "%s: invalid number of mel bands: %d (expected %d)\n", __func__, n_mel, WHISPER_N_MEL); + return -1; + } + + ctx->mel.n_len = n_len; + ctx->mel.n_mel = n_mel; + + ctx->mel.data.resize(n_len*n_mel); + memcpy(ctx->mel.data.data(), data, n_len*n_mel*sizeof(float)); + + return 0; +} + +int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) { + if (!whisper_encode(*ctx, offset, n_threads)) { + fprintf(stderr, "%s: failed to eval\n", __func__); + return -1; + } + + return 0; +} + +int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) { + // TODO: add selected_decoder_id to context + const int selected_decoder_id = 0; + + if (!whisper_decode(*ctx, ctx->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads)) { + fprintf(stderr, "%s: failed to eval\n", __func__); + return 1; + } + + return 0; +} + +int whisper_tokenize(struct whisper_context * ctx, const char * text, whisper_token * tokens, int n_max_tokens) { + const auto res = tokenize(ctx->vocab, text); + + if (n_max_tokens < (int) res.size()) { + fprintf(stderr, "%s: too many resulting tokens: %d (max %d)\n", __func__, (int) res.size(), n_max_tokens); + return -1; + } + + for (int i = 0; i < (int) res.size(); i++) { + tokens[i] = res[i]; + } + + return res.size(); +} + +int whisper_lang_max_id() { + auto max_id = 0; + for (const auto & kv : g_lang) { + max_id = std::max(max_id, kv.second.first); + } + + return max_id; +} + +int whisper_lang_id(const char * lang) { + if (!g_lang.count(lang)) { + for (const auto & kv : g_lang) { + if (kv.second.second == lang) { + return kv.second.first; + } + } + + fprintf(stderr, "%s: unknown language '%s'\n", __func__, lang); + return -1; + } + + return g_lang.at(lang).first; +} + +const char * whisper_lang_str(int id) { + for (const auto & kv : g_lang) { + if (kv.second.first == id) { + return kv.first.c_str(); + } + } + + fprintf(stderr, "%s: unknown language id %d\n", __func__, id); + return nullptr; +} + +int whisper_lang_auto_detect( + struct whisper_context * ctx, + int offset_ms, + int n_threads, + float * lang_probs) { + const int seek = offset_ms/10; + + if (seek < 0) { + fprintf(stderr, "%s: offset %dms is before the start of the audio\n", __func__, offset_ms); + return -1; + } + + if (seek >= ctx->mel.n_len) { + fprintf(stderr, "%s: offset %dms is past the end of the audio (%dms)\n", __func__, offset_ms, ctx->mel.n_len*10); + return -2; + } + + // run the encoder + if (whisper_encode(ctx, seek, n_threads) != 0) { + fprintf(stderr, "%s: failed to encode\n", __func__); + return -6; + } + + const std::vector prompt = { whisper_token_sot(ctx) }; + + if (whisper_decode(ctx, prompt.data(), prompt.size(), 0, n_threads) != 0) { + fprintf(stderr, "%s: failed to decode\n", __func__); + return -7; + } + + auto & logits_id = ctx->logits_id; + logits_id.clear(); + + for (const auto & kv : g_lang) { + const auto token_lang = whisper_token_lang(ctx, kv.second.first); + logits_id.emplace_back(ctx->logits[token_lang], kv.second.first); + } + + // sort descending + { + using pair_type = std::remove_reference::type::value_type; + std::sort(logits_id.begin(), logits_id.end(), [](const pair_type & a, const pair_type & b) { + 
return a.first > b.first; + }); + } + + // softmax + { + const auto max = logits_id[0].first; + + double sum = 0.0f; + for (auto & kv : logits_id) { + kv.first = exp(kv.first - max); + sum += kv.first; + } + + for (auto & kv : logits_id) { + kv.first /= sum; + } + } + + { + for (const auto & prob : logits_id) { + if (lang_probs) { + lang_probs[prob.second] = prob.first; + } + + //printf("%s: lang %2d (%3s): %f\n", __func__, prob.second, whisper_lang_str(prob.second), prob.first); + } + } + + return logits_id[0].second; +} + +int whisper_n_len(struct whisper_context * ctx) { + return ctx->mel.n_len; +} + +int whisper_n_vocab(struct whisper_context * ctx) { + return ctx->vocab.n_vocab; +} + +int whisper_n_text_ctx(struct whisper_context * ctx) { + return ctx->model.hparams.n_text_ctx; +} + +int whisper_n_audio_ctx(struct whisper_context * ctx) { + return ctx->model.hparams.n_audio_ctx; +} + +int whisper_is_multilingual(struct whisper_context * ctx) { + return ctx->vocab.is_multilingual() ? 1 : 0; +} + +float * whisper_get_logits(struct whisper_context * ctx) { + return ctx->logits.data(); +} + +const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token) { + return ctx->vocab.id_to_token.at(token).c_str(); +} + +whisper_token whisper_token_eot(struct whisper_context * ctx) { + return ctx->vocab.token_eot; +} + +whisper_token whisper_token_sot(struct whisper_context * ctx) { + return ctx->vocab.token_sot; +} + +whisper_token whisper_token_prev(struct whisper_context * ctx) { + return ctx->vocab.token_prev; +} + +whisper_token whisper_token_solm(struct whisper_context * ctx) { + return ctx->vocab.token_solm; +} + +whisper_token whisper_token_not(struct whisper_context * ctx) { + return ctx->vocab.token_not; +} + +whisper_token whisper_token_beg(struct whisper_context * ctx) { + return ctx->vocab.token_beg; +} + +whisper_token whisper_token_lang(struct whisper_context * ctx, int lang_id) { + return whisper_token_sot(ctx) + 1 + lang_id; +} + +whisper_token whisper_token_translate(void) { + return whisper_vocab::token_translate; +} + +whisper_token whisper_token_transcribe(void) { + return whisper_vocab::token_transcribe; +} + +void whisper_print_timings(struct whisper_context * ctx) { + const int64_t t_end_us = ggml_time_us(); + + const int32_t n_sample = std::max(1, ctx->n_sample); + const int32_t n_encode = std::max(1, ctx->n_encode); + const int32_t n_decode = std::max(1, ctx->n_decode); + + fprintf(stderr, "\n"); + fprintf(stderr, "%s: fallbacks = %3d p / %3d h\n", __func__, ctx->n_fail_p, ctx->n_fail_h); + fprintf(stderr, "%s: load time = %8.2f ms\n", __func__, ctx->t_load_us/1000.0f); + fprintf(stderr, "%s: mel time = %8.2f ms\n", __func__, ctx->t_mel_us/1000.0f); + fprintf(stderr, "%s: sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f*ctx->t_sample_us, n_sample, 1e-3f*ctx->t_sample_us/n_sample); + fprintf(stderr, "%s: encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f*ctx->t_encode_us, n_encode, 1e-3f*ctx->t_encode_us/n_encode); + fprintf(stderr, "%s: decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f*ctx->t_decode_us, n_decode, 1e-3f*ctx->t_decode_us/n_decode); + fprintf(stderr, "%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f); +} + +void whisper_reset_timings(struct whisper_context * ctx) { + ctx->t_sample_us = 0; + ctx->t_encode_us = 0; + ctx->t_decode_us = 0; +} + +const char * whisper_print_system_info(void) { + static std::string s; + + s = ""; + s += "AVX = " + 
std::to_string(ggml_cpu_has_avx()) + " | "; + s += "AVX2 = " + std::to_string(ggml_cpu_has_avx2()) + " | "; + s += "AVX512 = " + std::to_string(ggml_cpu_has_avx512()) + " | "; + s += "FMA = " + std::to_string(ggml_cpu_has_fma()) + " | "; + s += "NEON = " + std::to_string(ggml_cpu_has_neon()) + " | "; + s += "ARM_FMA = " + std::to_string(ggml_cpu_has_arm_fma()) + " | "; + s += "F16C = " + std::to_string(ggml_cpu_has_f16c()) + " | "; + s += "FP16_VA = " + std::to_string(ggml_cpu_has_fp16_va()) + " | "; + s += "WASM_SIMD = " + std::to_string(ggml_cpu_has_wasm_simd()) + " | "; + s += "BLAS = " + std::to_string(ggml_cpu_has_blas()) + " | "; + s += "SSE3 = " + std::to_string(ggml_cpu_has_sse3()) + " | "; + s += "VSX = " + std::to_string(ggml_cpu_has_vsx()) + " | "; + + return s.c_str(); +} + +//////////////////////////////////////////////////////////////////////////// + +struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy) { + struct whisper_full_params result = { + /*.strategy =*/ strategy, + + /*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()), + /*.n_max_text_ctx =*/ 16384, + /*.offset_ms =*/ 0, + /*.duration_ms =*/ 0, + + /*.translate =*/ false, + /*.no_context =*/ false, + /*.single_segment =*/ false, + /*.print_special =*/ false, + /*.print_progress =*/ true, + /*.print_realtime =*/ false, + /*.print_timestamps =*/ true, + + /*.token_timestamps =*/ false, + /*.thold_pt =*/ 0.01f, + /*.thold_ptsum =*/ 0.01f, + /*.max_len =*/ 0, + /*.split_on_word =*/ false, + /*.max_tokens =*/ 0, + + /*.speed_up =*/ false, + /*.audio_ctx =*/ 0, + + /*.prompt_tokens =*/ nullptr, + /*.prompt_n_tokens =*/ 0, + + /*.language =*/ "en", + + /*.suppress_blank =*/ true, + /*.suppress_non_speech_tokens =*/true, + + /*.temperature =*/ 0.0f, + /*.max_initial_ts =*/ 1.0f, + /*.length_penalty =*/ -1.0f, + + /*.temperature_inc =*/ 0.2f, + /*.entropy_thold =*/ 2.4f, + /*.logprob_thold =*/ -1.0f, + /*.no_speech_thold =*/ 0.6f, + + /*.greedy =*/ { + /*.best_of =*/ -1, + }, + + /*.beam_search =*/ { + /*.beam_size =*/ -1, + + /*.patience =*/ -1.0f, + }, + + /*.new_segment_callback =*/ nullptr, + /*.new_segment_callback_user_data =*/ nullptr, + + /*.encoder_begin_callback =*/ nullptr, + /*.encoder_begin_callback_user_data =*/ nullptr, + }; + + switch (strategy) { + case WHISPER_SAMPLING_GREEDY: + { + result.greedy = { + /*.best_of =*/ 1, + }; + } break; + case WHISPER_SAMPLING_BEAM_SEARCH: + { + result.beam_search = { + /*.beam_size =*/ 5, + + /*.patience =*/ -1.0f, + }; + } break; + } + + return result; +} + +// forward declarations +static std::vector get_signal_energy(const float * signal, int n_samples, int n_samples_per_half_window); +static void whisper_exp_compute_token_level_timestamps( + struct whisper_context & ctx, + int i_segment, + float thold_pt, + float thold_ptsum); + +// trim from start (in place) +static inline void ltrim(std::string &s) { + s.erase(s.begin(), std::find_if(s.begin(), s.end(), [](unsigned char ch) { + return !std::isspace(ch); + })); +} + +// trim from end (in place) +static inline void rtrim(std::string &s) { + s.erase(std::find_if(s.rbegin(), s.rend(), [](unsigned char ch) { + return !std::isspace(ch); + }).base(), s.end()); +} + +// trim from both ends (in place) +static inline void trim(std::string &s) { + rtrim(s); + ltrim(s); +} + +static inline bool should_split_on_word(const char * txt, bool split_on_word) { + if (!split_on_word) return true; + + return txt[0] == ' '; +} + +// wrap the last segment to max_len 
characters +// returns the number of new segments +static int whisper_wrap_segment(struct whisper_context & ctx, int max_len, bool split_on_word) { + auto segment = ctx.result_all.back(); + + int res = 1; + int acc = 0; + + std::string text; + + for (int i = 0; i < (int) segment.tokens.size(); i++) { + const auto & token = segment.tokens[i]; + if (token.id >= whisper_token_eot(&ctx)) { + continue; + } + + const auto txt = whisper_token_to_str(&ctx, token.id); + const int cur = strlen(txt); + + if (acc + cur > max_len && i > 0 && should_split_on_word(txt, split_on_word)) { + // split here + if (split_on_word) { + trim(text); + } + + ctx.result_all.back().text = std::move(text); + ctx.result_all.back().t1 = token.t0; + ctx.result_all.back().tokens.resize(i); + + ctx.result_all.push_back({}); + ctx.result_all.back().t0 = token.t0; + ctx.result_all.back().t1 = segment.t1; + + // add tokens [i, end] to the new segment + ctx.result_all.back().tokens.insert( + ctx.result_all.back().tokens.end(), + segment.tokens.begin() + i, + segment.tokens.end()); + + acc = 0; + text = ""; + + segment = ctx.result_all.back(); + i = -1; + + res++; + } else { + acc += cur; + text += txt; + } + } + + if (split_on_word) { + trim(text); + } + ctx.result_all.back().text = std::move(text); + + return res; +} + +static const std::vector non_speech_tokens +{ + "\"", "#", "(", ")", "*", "+", "/", ":", ";", "<", "=", ">", "@", "[", "\\", "]", "^", + "_", "`", "{", "|", "}", "~", "「", "」", "『", "』", "<<", ">>", "<<<", ">>>", "--", + "---", "-(", "-[", "('", "(\"", "((", "))", "(((", ")))", "[[", "]]", "{{", "}}", "♪♪", + "♪♪♪","♩", "♪", "♫", "♬", "♭", "♮", "♯" +}; + +// process the logits for the selected decoder +// - applies logit filters +// - computes logprobs and probs +static void whisper_process_logits( + const struct whisper_context & ctx, + const struct whisper_full_params params, + struct whisper_decoder & decoder, + float temperature) { + const auto & vocab = ctx.vocab; + const auto & tokens_cur = decoder.sequence.tokens; + + const bool is_initial = tokens_cur.size() == 0; + const int n_logits = vocab.id_to_token.size(); + + WHISPER_ASSERT(n_logits == ctx.vocab.n_vocab); + + // extract the logits for the last token + // we will be mutating and therefore we don't want to use the ctx.logits buffer directly + auto & probs = decoder.probs; + auto & logits = decoder.logits; + auto & logprobs = decoder.logprobs; + { + logits.resize(n_logits); + memcpy(logits.data(), ctx.logits.data() + (ctx.logits.size() - n_logits), n_logits*sizeof(float)); + + if (temperature > 0.0f) { + for (int i = 0; i < n_logits; i++) { + logits[i] /= temperature; + } + } + + // will be populated a bit later + probs.resize(n_logits); + logprobs.resize(n_logits); + } + + // apply logit filters here + // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L480-L493 + { + // suppress blank + // https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L388-L390 + if (params.suppress_blank) { + if (is_initial) { + logits[vocab.token_eot] = -INFINITY; + logits[vocab.token_to_id.at(" ")] = -INFINITY; + } + } + + // suppress <|notimestamps|> token + // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L410-L412 + logits[vocab.token_not] = -INFINITY; + + // suppress sot and solm tokens + logits[vocab.token_sot] = -INFINITY; + logits[vocab.token_solm] = -INFINITY; + + // 
suppress task tokens + logits[vocab.token_translate] = -INFINITY; + logits[vocab.token_transcribe] = -INFINITY; + + + // suppress non-speech tokens + // ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253 + if (params.suppress_non_speech_tokens) + { + for (const std::string &token : non_speech_tokens) + { + std::string suppress_tokens[] = {token, " " + token}; + for (const std::string &suppress_token : suppress_tokens) + { + if (vocab.token_to_id.find(suppress_token) != vocab.token_to_id.end()) + { + logits[vocab.token_to_id.at(suppress_token)] = -INFINITY; + } + } + } + // allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word + if (vocab.token_to_id.find(" -") != vocab.token_to_id.end()) + { + logits[vocab.token_to_id.at(" -")] = -INFINITY; + } + if (vocab.token_to_id.find(" '") != vocab.token_to_id.end()) + { + logits[vocab.token_to_id.at(" '")] = -INFINITY; + } + } + + // timestamps have to appear in pairs, except directly before EOT; mask logits accordingly + // https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L414-L424 + { + const bool last_was_timestamp = tokens_cur.size() > 0 && tokens_cur.back().id >= vocab.token_beg; + const bool penultimate_was_timestamp = tokens_cur.size() < 2 || tokens_cur[tokens_cur.size() - 2].id >= vocab.token_beg; + + //fprintf(stderr, "last_was_timestamp=%d penultimate_was_timestamp=%d\n", last_was_timestamp, penultimate_was_timestamp); + + if (last_was_timestamp) { + if (penultimate_was_timestamp) { + for (int i = vocab.token_beg; i < n_logits; ++i) { + logits[i] = -INFINITY; + } + } else { + for (int i = 0; i < vocab.token_eot; ++i) { + logits[i] = -INFINITY; + } + } + } + } + + // the initial timestamp cannot be larger than max_initial_ts + // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L426-L429 + if (is_initial && params.max_initial_ts > 0.0f) { + const float precision = float(WHISPER_CHUNK_SIZE)/ctx.model.hparams.n_audio_ctx; + const int tid0 = std::round(params.max_initial_ts/precision); + + for (int i = vocab.token_beg + tid0 + 1; i < n_logits; ++i) { + logits[i] = -INFINITY; + } + } + + // condition timestamp tokens to be increasing + // ref: https://github.com/openai/whisper/pull/831#issuecomment-1385910556 + if (decoder.has_ts) { + const int tid0 = decoder.seek_delta/2; + + for (int i = vocab.token_beg; i < vocab.token_beg + tid0; ++i) { + logits[i] = -INFINITY; + } + } + + // populate the logprobs array (log_softmax) + { + const float logit_max = *std::max_element(logits.begin(), logits.end()); + float logsumexp = 0.0f; + for (int i = 0; i < n_logits; ++i) { + if (logits[i] > -INFINITY) { + logsumexp += expf(logits[i] - logit_max); + } + } + logsumexp = logf(logsumexp) + logit_max; + + for (int i = 0; i < n_logits; ++i) { + if (logits[i] > -INFINITY) { + logprobs[i] = logits[i] - logsumexp; + } else { + logprobs[i] = -INFINITY; + } + } + } + + // if sum of probability over timestamps is above any other token, sample timestamp + // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L431-L437 + { + // logsumexp over timestamps + float timestamp_logprob = -INFINITY; + { + float logsumexp = 0.0f; + const float logprob_max = *std::max_element(logprobs.begin() + vocab.token_beg, logprobs.end()); + for (int i = vocab.token_beg; i < 
n_logits; ++i) { + if (logprobs[i] > -INFINITY) { + logsumexp += expf(logprobs[i] - logprob_max); + } + } + if (logsumexp > 0.0f) { + timestamp_logprob = logf(logsumexp) + logprob_max; + } + } + + const float max_text_token_logprob = *std::max_element(logprobs.begin(), logprobs.begin() + vocab.token_beg); + + //fprintf(stderr, "timestamp_logprob=%f max_text_token_logprob=%f\n", timestamp_logprob, max_text_token_logprob); + + if (timestamp_logprob > max_text_token_logprob) { + for (int i = 0; i < vocab.token_beg; ++i) { + logits[i] = -INFINITY; + logprobs[i] = -INFINITY; + } + } + } + } + + // compute probs + { + for (int i = 0; i < n_logits; ++i) { + if (logits[i] == -INFINITY) { + probs[i] = 0.0f; + } else { + probs[i] = expf(logprobs[i]); + } + } + } + +#if 0 + // print first 100 logits - token string : logit + for (int i = 0; i < 100; i++) { + const auto token = vocab.id_to_token.at(i); + const auto prob = probs[i]; + const auto logit = logits[i]; + const auto logprob = logprobs[i]; + printf("%s : prob=%9.5f logit=%9.5f logprob=%9.5f\n", token.c_str(), prob, logit, logprob); + } + + // "And", "and", " And", " and" + printf("logits[\"and\"] = %f\n", logits[vocab.token_to_id.at("and")]); + printf("logits[\"And\"] = %f\n", logits[vocab.token_to_id.at("And")]); + printf("logits[\" and\"] = %f\n", logits[vocab.token_to_id.at(" and")]); + printf("logits[\" And\"] = %f\n", logits[vocab.token_to_id.at(" And")]); + printf("logits[\" so\"] = %f\n", logits[vocab.token_to_id.at(" so")]); + + printf("logprobs[\"and\"] = %f\n", logprobs[vocab.token_to_id.at("and")]); + printf("logprobs[\"And\"] = %f\n", logprobs[vocab.token_to_id.at("And")]); + printf("logprobs[\" and\"] = %f\n", logprobs[vocab.token_to_id.at(" and")]); + printf("logprobs[\" And\"] = %f\n", logprobs[vocab.token_to_id.at(" And")]); + printf("logprobs[\" so\"] = %f\n", logprobs[vocab.token_to_id.at(" so")]); + + printf("probs[\"and\"] = %f\n", probs[vocab.token_to_id.at("and")]); + printf("probs[\"And\"] = %f\n", probs[vocab.token_to_id.at("And")]); + printf("probs[\" and\"] = %f\n", probs[vocab.token_to_id.at(" and")]); + printf("probs[\" And\"] = %f\n", probs[vocab.token_to_id.at(" And")]); + printf("probs[\" so\"] = %f\n", probs[vocab.token_to_id.at(" so")]); +#endif +} + +static whisper_token_data whisper_sample_token( + whisper_context & ctx, + const whisper_decoder & decoder, + bool best) { + whisper_token_data result = { + 0, 0, 0.0f, 0.0f, 0.0f, 0.0f, -1, -1, 0.0f, + }; + + const auto & vocab = ctx.vocab; + + const auto & probs = decoder.probs; + const auto & logprobs = decoder.logprobs; + + const int n_logits = vocab.n_vocab; + + { + double sum_ts = 0.0; + double max_ts = 0.0; + + for (int i = vocab.token_beg; i < n_logits; i++) { + if (probs[i] == -INFINITY) { + continue; + } + + sum_ts += probs[i]; + if (max_ts < probs[i]) { + max_ts = probs[i]; + result.tid = i; + } + } + + result.pt = max_ts/(sum_ts + 1e-10); + result.ptsum = sum_ts; + } + + if (best) { + for (int i = 0; i < n_logits; ++i) { + if (result.p < probs[i]) { + result.id = i; + result.p = probs[i]; + result.plog = logprobs[i]; + } + } + } else { + std::discrete_distribution<> dist(probs.begin(), probs.end()); + + result.id = dist(ctx.rng); + result.p = probs[result.id]; + result.plog = logprobs[result.id]; + } + + if (result.id >= vocab.token_beg) { + result.tid = result.id; + result.pt = result.p; + } + + ctx.n_sample++; + + return result; +} + +static std::vector whisper_sample_token_topk( + whisper_context & ctx, + const whisper_decoder & decoder, + int k) { 
+ const auto & vocab = ctx.vocab; + + const auto & probs = decoder.probs; + const auto & logits = decoder.logits; + const auto & logprobs = decoder.logprobs; + + const int n_logits = vocab.n_vocab; + + auto & logits_id = ctx.logits_id; + + logits_id.clear(); + for (int i = 0; i < n_logits; ++i) { + logits_id.push_back({ logits[i], i }); + } + + std::partial_sort( + logits_id.begin(), + logits_id.begin() + k, logits_id.end(), + [](const std::pair & a, const std::pair & b) { + return a.first > b.first; + }); + + std::vector result; + result.reserve(k); + + whisper_token tid = vocab.token_beg; + + float pt = 0.0; + float ptsum = 0.0; + + { + double sum_ts = 0.0; + double max_ts = 0.0; + + for (int i = vocab.token_beg; i < n_logits; i++) { + if (probs[i] == -INFINITY) { + continue; + } + + sum_ts += probs[i]; + if (max_ts < probs[i]) { + max_ts = probs[i]; + tid = i; + } + } + + pt = max_ts/(sum_ts + 1e-10); + ptsum = sum_ts; + } + + for (int i = 0; i < k; ++i) { + const auto id = logits_id[i].second; + + result.push_back({ id, tid, probs[id], logprobs[id], pt, ptsum, -1, -1, 0.0f, }); + + if (result[i].id >= vocab.token_beg) { + result[i].tid = result[i].id; + result[i].pt = result[i].p; + } + } + + ctx.n_sample++; + + return result; +} + +// ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L178-L192 +static void whisper_sequence_score( + const struct whisper_full_params & params, + whisper_sequence & sequence) { + if (sequence.result_len == 0) { + return; + } + + double result = 0.0f; + + for (int i = 0; i < sequence.result_len; ++i) { + result += sequence.tokens[i].plog; + } + + sequence.sum_logprobs = result; + sequence.avg_logprobs = result/sequence.result_len; + + double penalty = sequence.result_len; + + if (params.length_penalty > 0.0f) { + penalty = pow((5.0 + penalty)/6.0, params.length_penalty); + } + + sequence.score = result/penalty; + + // compute the entropy of the sequence of the last 32 tokens + { + const int n = 32; + + int cnt = 0; + double entropy = 0.0f; + + std::map token_counts; + for (int i = std::max(0, sequence.result_len - n); i < sequence.result_len; ++i) { + token_counts[sequence.tokens[i].id]++; + cnt++; + } + + for (const auto & kv : token_counts) { + const auto p = kv.second/(double)cnt; + entropy -= p*log(p); + + //WHISPER_PRINT_DEBUG("entropy: %d %f %f, count %d\n", kv.first, p, log(p), kv.second); + } + + sequence.entropy = entropy; + } +} + +int whisper_full( + struct whisper_context * ctx, + struct whisper_full_params params, + const float * samples, + int n_samples) { + // clear old results + auto & result_all = ctx->result_all; + + result_all.clear(); + + // compute log mel spectrogram + if (params.speed_up) { + if (whisper_pcm_to_mel_phase_vocoder(ctx, samples, n_samples, params.n_threads) != 0) { + fprintf(stderr, "%s: failed to compute log mel spectrogram\n", __func__); + return -1; + } + } else { + if (whisper_pcm_to_mel(ctx, samples, n_samples, params.n_threads) != 0) { + fprintf(stderr, "%s: failed to compute log mel spectrogram\n", __func__); + return -2; + } + } + + // auto-detect language if not specified + if (params.language == nullptr || strlen(params.language) == 0 || strcmp(params.language, "auto") == 0) { + std::vector probs(whisper_lang_max_id() + 1, 0.0f); + + const auto lang_id = whisper_lang_auto_detect(ctx, 0, params.n_threads, probs.data()); + if (lang_id < 0) { + fprintf(stderr, "%s: failed to auto-detect language\n", __func__); + return -3; + } + ctx->lang_id = 
lang_id; + params.language = whisper_lang_str(lang_id); + + fprintf(stderr, "%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[whisper_lang_id(params.language)]); + } + + if (params.token_timestamps) { + ctx->t_beg = 0; + ctx->t_last = 0; + ctx->tid_last = 0; + ctx->energy = get_signal_energy(samples, n_samples, 32); + } + + const int seek_start = params.offset_ms/10; + const int seek_end = seek_start + (params.duration_ms == 0 ? whisper_n_len(ctx) : params.duration_ms/10); + + // if length of spectrogram is less than 1s (100 samples), then return + // basically don't process anything that is less than 1s + // see issue #39: https://github.com/ggerganov/whisper.cpp/issues/39 + if (seek_end < seek_start + (params.speed_up ? 50 : 100)) { + return 0; + } + + // a set of temperatures to use + // [ t0, t0 + delta, t0 + 2*delta, ..., < 1.0f + 1e-6f ] + std::vector temperatures; + if (params.temperature_inc > 0.0f) { + for (float t = params.temperature; t < 1.0f + 1e-6f; t += params.temperature_inc) { + temperatures.push_back(t); + } + } else { + temperatures.push_back(params.temperature); + } + + // initialize the decoders + int n_decoders = 1; + + switch (params.strategy) { + case WHISPER_SAMPLING_GREEDY: + { + n_decoders = params.greedy.best_of; + } break; + case WHISPER_SAMPLING_BEAM_SEARCH: + { + n_decoders = std::max(params.greedy.best_of, params.beam_search.beam_size); + } break; + }; + + n_decoders = std::max(1, n_decoders); + + // TAGS: WHISPER_DECODER_INIT + for (int j = 1; j < n_decoders; j++) { + auto & decoder = ctx->decoders[j]; + + if (decoder.kv_self.ctx == nullptr) { + decoder.kv_self = ctx->decoders[0].kv_self; + if (!kv_cache_reinit(decoder.kv_self)) { + fprintf(stderr, "%s: kv_cache_reinit() failed for self-attention, decoder %d\n", __func__, j); + return -4; + } + + WHISPER_PRINT_DEBUG("%s: initialized self-attention kv cache, decoder %d\n", __func__, j); + + decoder.sequence.tokens.reserve(ctx->decoders[0].sequence.tokens.capacity()); + + decoder.probs.resize (ctx->vocab.n_vocab); + decoder.logits.resize (ctx->vocab.n_vocab); + decoder.logprobs.resize(ctx->vocab.n_vocab); + } + } + + // the accumulated text context so far + auto & prompt_past = ctx->prompt_past; + if (params.no_context) { + prompt_past.clear(); + } + + // prepend the prompt tokens to the prompt_past + if (params.prompt_tokens && params.prompt_n_tokens > 0) { + // parse tokens from the pointer + for (int i = 0; i < params.prompt_n_tokens; i++) { + prompt_past.push_back(params.prompt_tokens[i]); + } + std::rotate(prompt_past.begin(), prompt_past.end() - params.prompt_n_tokens, prompt_past.end()); + } + + // overwrite audio_ctx, max allowed is hparams.n_audio_ctx + if (params.audio_ctx > whisper_n_audio_ctx(ctx)) { + fprintf(stderr, "%s: audio_ctx is larger than the maximum allowed (%d > %d)\n", __func__, params.audio_ctx, whisper_n_audio_ctx(ctx)); + return -5; + } + ctx->exp_n_audio_ctx = params.audio_ctx; + + // these tokens determine the task that will be performed + std::vector prompt_init = { whisper_token_sot(ctx) }; + if (whisper_is_multilingual(ctx)) { + const int lang_id = whisper_lang_id(params.language); + ctx->lang_id = lang_id; + prompt_init.push_back(whisper_token_lang(ctx, lang_id)); + if (params.translate) { + prompt_init.push_back(whisper_token_translate()); + } else { + prompt_init.push_back(whisper_token_transcribe()); + } + } + + int progress_prev = 0; + int progress_step = 5; + + int seek = seek_start; + + std::vector prompt; + 
prompt.reserve(whisper_n_text_ctx(ctx)); + + // beam-search helpers + struct kv_buf { + std::vector k; + std::vector v; + }; + + std::vector kv_bufs; + + struct beam_candidate { + int decoder_idx; + int seek_delta; + + bool has_ts; + + whisper_sequence sequence; + }; + + std::vector beam_candidates; + + // main loop + while (true) { + const int progress_cur = (100*(seek - seek_start))/(seek_end - seek_start); + while (progress_cur >= progress_prev + progress_step) { + progress_prev += progress_step; + if (params.print_progress) { + fprintf(stderr, "%s: progress = %3d%%\n", __func__, progress_prev); + } + } + + // of only 1 second left, then stop + if (seek + 100 >= seek_end) { + break; + } + + if (params.encoder_begin_callback) { + if (params.encoder_begin_callback(ctx, params.encoder_begin_callback_user_data) == false) { + fprintf(stderr, "%s: encoder_begin_callback returned false - aborting\n", __func__); + break; + } + } + + // encode audio features starting at offset seek + if (!whisper_encode(*ctx, seek, params.n_threads)) { + fprintf(stderr, "%s: failed to encode\n", __func__); + return -6; + } + + // if there is a very short audio segment left to process, we remove any past prompt since it tends + // to confuse the decoder and often make it repeat or hallucinate stuff + if (seek > seek_start && seek + 500 >= seek_end) { + prompt_past.clear(); + } + + int best_decoder_id = 0; + + for (int it = 0; it < (int) temperatures.size(); ++it) { + const float t_cur = temperatures[it]; + + int n_decoders_cur = 1; + + switch (params.strategy) { + case whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY: + { + if (t_cur > 0.0f) { + n_decoders_cur = params.greedy.best_of; + } + } break; + case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH: + { + if (t_cur > 0.0f) { + n_decoders_cur = params.greedy.best_of; + } else { + n_decoders_cur = params.beam_search.beam_size; + } + } break; + }; + + n_decoders_cur = std::max(1, n_decoders_cur); + + WHISPER_PRINT_DEBUG("\n%s: decoding with %d decoders, temperature = %.2f\n", __func__, n_decoders_cur, t_cur); + + // TAGS: WHISPER_DECODER_INIT + for (int j = 0; j < n_decoders_cur; ++j) { + auto & decoder = ctx->decoders[j]; + + decoder.kv_self.n = 0; + + decoder.sequence.tokens.clear(); + decoder.sequence.result_len = 0; + decoder.sequence.sum_logprobs_all = 0.0; + decoder.sequence.sum_logprobs = -INFINITY; + decoder.sequence.avg_logprobs = -INFINITY; + decoder.sequence.entropy = 0.0; + decoder.sequence.score = -INFINITY; + + decoder.seek_delta = 100*WHISPER_CHUNK_SIZE; + + decoder.failed = false; + decoder.completed = false; + decoder.has_ts = false; + } + + // init prompt and kv cache for the current iteration + // run whisper_decoder() only for decoder 0 and copy the results for the other decoders + { + prompt.clear(); + + // if we have already generated some text, use it as a prompt to condition the next generation + if (!prompt_past.empty() && t_cur < 0.5f && params.n_max_text_ctx > 0) { + int n_take = std::min(std::min(params.n_max_text_ctx, whisper_n_text_ctx(ctx)/2), int(prompt_past.size())); + + prompt = { whisper_token_prev(ctx) }; + prompt.insert(prompt.begin() + 1, prompt_past.end() - n_take, prompt_past.end()); + } + + // init new transcription with sot, language (opt) and task tokens + prompt.insert(prompt.end(), prompt_init.begin(), prompt_init.end()); + + // print the prompt + WHISPER_PRINT_DEBUG("\n\n"); + for (int i = 0; i < (int) prompt.size(); i++) { + WHISPER_PRINT_DEBUG("%s: prompt[%d] = %s\n", __func__, i, 
ctx->vocab.id_to_token.at(prompt[i]).c_str()); + } + WHISPER_PRINT_DEBUG("\n\n"); + + if (!whisper_decode(*ctx, ctx->decoders[0], prompt.data(), prompt.size(), 0, params.n_threads)) { + fprintf(stderr, "%s: failed to decode\n", __func__); + return -7; + } + + { + const int64_t t_start_sample_us = ggml_time_us(); + + whisper_process_logits(*ctx, params, ctx->decoders[0], t_cur); + + ctx->decoders[0].kv_self.n += prompt.size(); + + for (int j = 1; j < n_decoders_cur; ++j) { + auto & decoder = ctx->decoders[j]; + + memcpy(decoder.kv_self.k->data, ctx->decoders[0].kv_self.k->data, ggml_nbytes(decoder.kv_self.k)); + memcpy(decoder.kv_self.v->data, ctx->decoders[0].kv_self.v->data, ggml_nbytes(decoder.kv_self.v)); + + decoder.kv_self.n += prompt.size(); + + memcpy(decoder.probs.data(), ctx->decoders[0].probs.data(), decoder.probs.size()*sizeof(decoder.probs[0])); + memcpy(decoder.logits.data(), ctx->decoders[0].logits.data(), decoder.logits.size()*sizeof(decoder.logits[0])); + memcpy(decoder.logprobs.data(), ctx->decoders[0].logprobs.data(), decoder.logprobs.size()*sizeof(decoder.logprobs[0])); + } + + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + } + } + + for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) { + const int64_t t_start_sample_us = ggml_time_us(); + + // store the KV caches of all decoders when doing beam-search + if (params.strategy == whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH) { + kv_bufs.resize(n_decoders_cur); + for (int j = 0; j < n_decoders_cur; ++j) { + auto & decoder = ctx->decoders[j]; + + if (decoder.completed || decoder.failed) { + continue; + } + + kv_bufs[j].k.resize(ggml_nbytes(decoder.kv_self.k)); + kv_bufs[j].v.resize(ggml_nbytes(decoder.kv_self.v)); + + memcpy(kv_bufs[j].k.data(), decoder.kv_self.k->data, kv_bufs[j].k.size()); + memcpy(kv_bufs[j].v.data(), decoder.kv_self.v->data, kv_bufs[j].v.size()); + } + + beam_candidates.clear(); + } + + // generate new sequence candidates for each decoder + for (int j = 0; j < n_decoders_cur; ++j) { + auto & decoder = ctx->decoders[j]; + + if (decoder.completed || decoder.failed) { + continue; + } + + switch (params.strategy) { + case whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY: + { + if (t_cur < 1e-6f) { + decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, decoder, true)); + } else { + decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, decoder, false)); + } + + decoder.sequence.sum_logprobs_all += decoder.sequence.tokens.back().plog; + } break; + case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH: + { + const auto tokens_new = whisper_sample_token_topk(*ctx, decoder, params.beam_search.beam_size); + + for (const auto & token : tokens_new) { + beam_candidates.push_back({ j, decoder.seek_delta, decoder.has_ts, decoder.sequence }); + beam_candidates.back().sequence.tokens.push_back(token); + beam_candidates.back().sequence.sum_logprobs_all += token.plog; + + //WHISPER_PRINT_DEBUG("%s: beam candidate: %s (%f, %f)\n", __func__, ctx->vocab.id_to_token.at(token.id).c_str(), token.plog, beam_candidates.back().sequence.sum_logprobs_all); + } + } break; + }; + } + + // for beam-search, choose the top candidates and update the KV caches + if (params.strategy == whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH) { + std::sort( + beam_candidates.begin(), + beam_candidates.end(), + [](const beam_candidate & a, const beam_candidate & b) { + return a.sequence.sum_logprobs_all > b.sequence.sum_logprobs_all; + }); + + int cur_c = 0; + + for (int j = 0; j < 
n_decoders_cur; ++j) { + auto & decoder = ctx->decoders[j]; + + if (decoder.completed || decoder.failed) { + continue; + } + + auto & cur = beam_candidates[cur_c++]; + + while (beam_candidates.size() > cur_c && beam_candidates[cur_c].sequence.sum_logprobs_all == cur.sequence.sum_logprobs_all && i > 0) { + ++cur_c; + } + + decoder.sequence = cur.sequence; + decoder.seek_delta = cur.seek_delta; + decoder.has_ts = cur.has_ts; + + memcpy(decoder.kv_self.k->data, kv_bufs[cur.decoder_idx].k.data(), kv_bufs[cur.decoder_idx].k.size()); + memcpy(decoder.kv_self.v->data, kv_bufs[cur.decoder_idx].v.data(), kv_bufs[cur.decoder_idx].v.size()); + + WHISPER_PRINT_DEBUG("%s: beam search: decoder %d: from decoder %d: token = %10s, plog = %8.5f, sum_logprobs = %8.5f\n", + __func__, j, cur.decoder_idx, ctx->vocab.id_to_token.at(decoder.sequence.tokens.back().id).c_str(), decoder.sequence.tokens.back().plog, decoder.sequence.sum_logprobs_all); + } + } + + // update the decoder state + // - check if the sequence is completed + // - check if the sequence is failed + // - update sliding window based on timestamp tokens + for (int j = 0; j < n_decoders_cur; ++j) { + auto & decoder = ctx->decoders[j]; + + if (decoder.completed || decoder.failed) { + continue; + } + + auto & has_ts = decoder.has_ts; + auto & failed = decoder.failed; + auto & completed = decoder.completed; + auto & seek_delta = decoder.seek_delta; + auto & result_len = decoder.sequence.result_len; + + { + const auto & token = decoder.sequence.tokens.back(); + + // timestamp token - update sliding window + if (token.id > whisper_token_beg(ctx)) { + const int seek_delta_new = 2*(token.id - whisper_token_beg(ctx)); + + // do not allow to go back in time + if (has_ts && seek_delta > seek_delta_new && result_len < i) { + failed = true; // TODO: maybe this is not a failure ? + continue; + } + + seek_delta = seek_delta_new; + result_len = i + 1; + has_ts = true; + } + +#ifdef WHISPER_DEBUG + { + const auto tt = token.pt > 0.10 ? ctx->vocab.id_to_token.at(token.tid) : "[?]"; + WHISPER_PRINT_DEBUG("%s: id = %3d, decoder = %d, token = %6d, p = %6.3f, ts = %10s, %6.3f, result_len = %4d '%s'\n", + __func__, i, j, token.id, token.p, tt.c_str(), token.pt, result_len, ctx->vocab.id_to_token.at(token.id).c_str()); + } +#endif + + // end of segment + if (token.id == whisper_token_eot(ctx) || // end of text token + (params.max_tokens > 0 && i >= params.max_tokens) || // max tokens per segment reached + (has_ts && seek + seek_delta + 100 >= seek_end) // end of audio reached + ) { + if (result_len == 0) { + if (seek + seek_delta + 100 >= seek_end) { + result_len = i + 1; + } else { + failed = true; + continue; + } + } + + if (params.single_segment) { + result_len = i + 1; + seek_delta = 100*WHISPER_CHUNK_SIZE; + } + + completed = true; + continue; + } + + // TESTS: if no tensors are loaded, it means we are running tests + if (ctx->model.n_loaded == 0) { + seek_delta = 100*WHISPER_CHUNK_SIZE; + completed = true; + continue; + } + } + + // sometimes, the decoding can get stuck in a repetition loop + // this is an attempt to mitigate such cases - we flag the decoding as failed and use a fallback strategy + if (i == n_max - 1 && (result_len == 0 || seek_delta < 100*WHISPER_CHUNK_SIZE/2)) { + failed = true; + continue; + } + } + + // check if all decoders have finished (i.e. 
completed or failed) + { + bool completed_all = true; + + for (int j = 0; j < n_decoders_cur; ++j) { + auto & decoder = ctx->decoders[j]; + + if (decoder.completed || decoder.failed) { + continue; + } + + completed_all = false; + } + + if (completed_all) { + break; + } + } + + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + + // obtain logits for the next token + for (int j = 0; j < n_decoders_cur; ++j) { + auto & decoder = ctx->decoders[j]; + + if (decoder.failed || decoder.completed) { + continue; + } + + decoder.tokens_tmp.resize(1); + decoder.tokens_tmp[0] = decoder.sequence.tokens.back().id; + + //WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, kv_self.n %d, seek_delta %d\n", __func__, j, decoder.tokens_tmp[0], decoder.kv_self.n, decoder.seek_delta); + + if (!whisper_decode(*ctx, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.kv_self.n, params.n_threads)) { + fprintf(stderr, "%s: failed to decode\n", __func__); + return -8; + } + + { + const int64_t t_start_sample_us = ggml_time_us(); + + whisper_process_logits(*ctx, params, decoder, t_cur); + + ++decoder.kv_self.n; + + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + } + } + } + + // rank the resulting sequences and select the best one + { + double best_score = -INFINITY; + + for (int j = 0; j < n_decoders_cur; ++j) { + auto & decoder = ctx->decoders[j]; + + if (decoder.failed) { + continue; + } + + decoder.sequence.tokens.resize(decoder.sequence.result_len); + whisper_sequence_score(params, decoder.sequence); + + WHISPER_PRINT_DEBUG("%s: decoder %2d: score = %8.5f, result_len = %3d, avg_logprobs = %8.5f, entropy = %8.5f\n", + __func__, j, decoder.sequence.score, decoder.sequence.result_len, decoder.sequence.avg_logprobs, decoder.sequence.entropy); + + if (decoder.sequence.result_len > 32 && decoder.sequence.entropy < params.entropy_thold) { + WHISPER_PRINT_DEBUG("%s: decoder %2d: failed due to entropy %8.5f < %8.5f\n", + __func__, j, decoder.sequence.entropy, params.entropy_thold); + + decoder.failed = true; + ctx->n_fail_h++; + + continue; + } + + if (best_score < decoder.sequence.score) { + best_score = decoder.sequence.score; + best_decoder_id = j; + } + } + + WHISPER_PRINT_DEBUG("%s: best decoder = %d\n", __func__, best_decoder_id); + } + + // was the decoding successful for the current temperature? 
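+ // if the best decoder failed, or its average log-probability is below params.logprob_thold, the result for this temperature is rejected and whisper_full retries the decoding with the next (higher) temperature from the fallback schedule (see the fallback parameters in whisper_full_params)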
+ { + bool success = true; + + const auto & decoder = ctx->decoders[best_decoder_id]; + + if (decoder.failed || decoder.sequence.avg_logprobs < params.logprob_thold) { + success = false; + ctx->n_fail_p++; + } + + if (success) { + //for (auto & token : ctx->decoders[best_decoder_id].sequence.tokens) { + // WHISPER_PRINT_DEBUG("%s: token = %d, p = %6.3f, pt = %6.3f, ts = %s, str = %s\n", __func__, token.id, token.p, token.pt, ctx->vocab.id_to_token.at(token.tid).c_str(), ctx->vocab.id_to_token.at(token.id).c_str()); + //} + + break; + } + } + + WHISPER_PRINT_DEBUG("\n%s: failed to decode with temperature = %.2f\n", __func__, t_cur); + } + + // output results through a user-provided callback + { + const auto & best_decoder = ctx->decoders[best_decoder_id]; + + const auto seek_delta = best_decoder.seek_delta; + const auto result_len = best_decoder.sequence.result_len; + + const auto & tokens_cur = best_decoder.sequence.tokens; + + //WHISPER_PRINT_DEBUG("prompt_init.size() = %d, prompt.size() = %d, result_len = %d, seek_delta = %d\n", prompt_init.size(), prompt.size(), result_len, seek_delta); + + // update prompt_past + prompt_past.clear(); + if (prompt.front() == whisper_token_prev(ctx)) { + prompt_past.insert(prompt_past.end(), prompt.begin() + 1, prompt.end() - prompt_init.size()); + } + + for (int i = 0; i < result_len; ++i) { + prompt_past.push_back(tokens_cur[i].id); + } + + // store the text from this iteration + if (!tokens_cur.empty() && ctx->model.n_loaded > 0) { + int i0 = 0; + auto t0 = seek + 2*(tokens_cur.front().tid - whisper_token_beg(ctx)); + + std::string text; + + for (int i = 0; i < (int) tokens_cur.size(); i++) { + //printf("%s: %18s %6.3f %18s %6.3f\n", __func__, + // ctx->vocab.id_to_token[tokens_cur[i].id].c_str(), tokens_cur[i].p, + // ctx->vocab.id_to_token[tokens_cur[i].tid].c_str(), tokens_cur[i].pt); + + if (params.print_special == false && tokens_cur[i].id >= whisper_token_eot(ctx)) { + } else { + text += whisper_token_to_str(ctx, tokens_cur[i].id); + } + + if (tokens_cur[i].id > whisper_token_beg(ctx) && !params.single_segment) { + const auto t1 = seek + 2*(tokens_cur[i].tid - whisper_token_beg(ctx)); + + if (!text.empty()) { + const auto tt0 = params.speed_up ? 2*t0 : t0; + const auto tt1 = params.speed_up ? 2*t1 : t1; + + if (params.print_realtime) { + if (params.print_timestamps) { + printf("[%s --> %s] %s\n", to_timestamp(tt0).c_str(), to_timestamp(tt1).c_str(), text.c_str()); + } else { + printf("%s", text.c_str()); + fflush(stdout); + } + } + + //printf("tt0 = %d, tt1 = %d, text = %s, token = %s, token_id = %d, tid = %d\n", tt0, tt1, text.c_str(), ctx->vocab.id_to_token[tokens_cur[i].id].c_str(), tokens_cur[i].id, tokens_cur[i].tid); + + result_all.push_back({ tt0, tt1, text, {} }); + for (int j = i0; j <= i; j++) { + result_all.back().tokens.push_back(tokens_cur[j]); + } + + int n_new = 1; + + if (params.token_timestamps) { + whisper_exp_compute_token_level_timestamps( + *ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum); + + if (params.max_len > 0) { + n_new = whisper_wrap_segment(*ctx, params.max_len, params.split_on_word); + } + } + if (params.new_segment_callback) { + params.new_segment_callback(ctx, n_new, params.new_segment_callback_user_data); + } + } + text = ""; + while (i < (int) tokens_cur.size() && tokens_cur[i].id > whisper_token_beg(ctx)) { + i++; + } + i--; + t0 = t1; + i0 = i + 1; + } + } + + if (!text.empty()) { + const auto t1 = seek + seek_delta; + + const auto tt0 = params.speed_up ? 
2*t0 : t0; + const auto tt1 = params.speed_up ? 2*t1 : t1; + + if (params.print_realtime) { + if (params.print_timestamps) { + printf("[%s --> %s] %s\n", to_timestamp(tt0).c_str(), to_timestamp(tt1).c_str(), text.c_str()); + } else { + printf("%s", text.c_str()); + fflush(stdout); + } + } + + result_all.push_back({ tt0, tt1, text, {} }); + for (int j = i0; j < (int) tokens_cur.size(); j++) { + result_all.back().tokens.push_back(tokens_cur[j]); + } + + int n_new = 1; + + if (params.token_timestamps) { + whisper_exp_compute_token_level_timestamps( + *ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum); + + if (params.max_len > 0) { + n_new = whisper_wrap_segment(*ctx, params.max_len, params.split_on_word); + } + } + if (params.new_segment_callback) { + params.new_segment_callback(ctx, n_new, params.new_segment_callback_user_data); + } + } + } + + // update audio window + seek += seek_delta; + + WHISPER_PRINT_DEBUG("seek = %d, seek_delta = %d\n", seek, seek_delta); + } + } + + return 0; +} + +int whisper_full_parallel( + struct whisper_context * ctx, + struct whisper_full_params params, + const float * samples, + int n_samples, + int n_processors) { + if (n_processors == 1) { + return whisper_full(ctx, params, samples, n_samples); + } + + int ret = 0; + + // prepare separate contexts for each thread + std::vector ctxs(n_processors - 1); + + for (int i = 0; i < n_processors - 1; ++i) { + auto & ctx_p = ctxs[i]; + + ctx_p = *ctx; + + ctx_p.logits.reserve(ctx_p.vocab.n_vocab*ctx_p.model.hparams.n_text_ctx); + + ctx_p.logits_id.reserve(ctx_p.vocab.n_vocab); + + if (!kv_cache_reinit(ctx_p.kv_cross)) { + fprintf(stderr, "%s: kv_cache_reinit() failed for cross-attention, processor %d\n", __func__, i); + return false; + } + + // TAGS: WHISPER_DECODER_INIT + for (int j = 0; j < WHISPER_MAX_DECODERS; ++j) { + if (ctx_p.decoders[j].kv_self.ctx && !kv_cache_reinit(ctx_p.decoders[j].kv_self)) { + fprintf(stderr, "%s: kv_cache_reinit() failed for self-attention, decoder %d, processor %d\n", __func__, j, i); + return false; + } + + ctx_p.decoders[j].sequence.tokens.reserve(ctx_p.model.hparams.n_text_ctx); + + ctx_p.decoders[j].probs.reserve (ctx_p.vocab.n_vocab); + ctx_p.decoders[j].logits.reserve (ctx_p.vocab.n_vocab); + ctx_p.decoders[j].logprobs.reserve(ctx_p.vocab.n_vocab); + } + } + + const int offset_samples = (WHISPER_SAMPLE_RATE*params.offset_ms)/1000; + const int n_samples_per_processor = (n_samples - offset_samples)/n_processors; + + // the calling thread will process the first chunk + // while the other threads will process the remaining chunks + + std::vector workers(n_processors - 1); + for (int i = 0; i < n_processors - 1; ++i) { + const int start_samples = offset_samples + (i + 1)*n_samples_per_processor; + const int n_samples_cur = (i == n_processors - 2) ? 
n_samples - start_samples : n_samples_per_processor; + + auto params_cur = params; + + params_cur.offset_ms = 0; + params_cur.print_progress = false; + params_cur.print_realtime = false; + + params_cur.new_segment_callback = nullptr; + params_cur.new_segment_callback_user_data = nullptr; + + workers[i] = std::thread(whisper_full, &ctxs[i], std::move(params_cur), samples + start_samples, n_samples_cur); + } + + { + auto params_cur = params; + + ret = whisper_full(ctx, std::move(params_cur), samples, offset_samples + n_samples_per_processor); + } + + for (int i = 0; i < n_processors - 1; ++i) { + workers[i].join(); + } + + const int64_t offset_t = (int64_t) params.offset_ms/10.0; + + // combine results into ctx->result_all + for (int i = 0; i < n_processors - 1; ++i) { + auto & results_i = ctxs[i].result_all; + + for (auto & result : results_i) { + // correct the segment timestamp taking into account the offset + result.t0 += 100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t; + result.t1 += 100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t; + + // make sure that segments are not overlapping + if (!ctx->result_all.empty()) { + result.t0 = std::max(result.t0, ctx->result_all.back().t1); + } + + ctx->result_all.push_back(std::move(result)); + + // call the new_segment_callback for each segment + if (params.new_segment_callback) { + params.new_segment_callback(ctx, 1, params.new_segment_callback_user_data); + } + } + + ctx->t_mel_us += ctxs[i].t_mel_us; + ctx->t_sample_us += ctxs[i].t_sample_us; + ctx->t_encode_us += ctxs[i].t_encode_us; + ctx->t_decode_us += ctxs[i].t_decode_us; + + kv_cache_free(ctx->kv_cross); + + for (int j = 0; j < WHISPER_MAX_DECODERS; ++j) { + kv_cache_free(ctx->decoders[j].kv_self); + } + } + + // average the timings + ctx->t_mel_us /= n_processors; + ctx->t_sample_us /= n_processors; + ctx->t_encode_us /= n_processors; + ctx->t_decode_us /= n_processors; + + // print information about the audio boundaries + fprintf(stderr, "\n"); + fprintf(stderr, "%s: the audio has been split into %d chunks at the following times:\n", __func__, n_processors); + for (int i = 0; i < n_processors - 1; ++i) { + fprintf(stderr, "%s: split %d - %s\n", __func__, (i + 1), to_timestamp(100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t).c_str()); + } + fprintf(stderr, "%s: the transcription quality may be degraded near these boundaries\n", __func__); + + return ret; +} + +int whisper_full_n_segments(struct whisper_context * ctx) { + return ctx->result_all.size(); +} + +int whisper_full_lang_id(struct whisper_context * ctx) { + return ctx->lang_id; +} + +int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment) { + return ctx->result_all[i_segment].t0; +} + +int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment) { + return ctx->result_all[i_segment].t1; +} + +const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment) { + return ctx->result_all[i_segment].text.c_str(); +} + +int whisper_full_n_tokens(struct whisper_context * ctx, int i_segment) { + return ctx->result_all[i_segment].tokens.size(); +} + +const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token) { + return ctx->vocab.id_to_token[ctx->result_all[i_segment].tokens[i_token].id].c_str(); +} + +whisper_token whisper_full_get_token_id(struct whisper_context * ctx, int i_segment, int i_token) { + return ctx->result_all[i_segment].tokens[i_token].id; +} + +struct 
whisper_token_data whisper_full_get_token_data(struct whisper_context * ctx, int i_segment, int i_token) { + return ctx->result_all[i_segment].tokens[i_token]; +} + +float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token) { + return ctx->result_all[i_segment].tokens[i_token].p; +} + +// ================================================================================================= + +// +// Temporary interface needed for exposing ggml interface +// Will be removed in the future when ggml becomes a separate library +// + +WHISPER_API int whisper_bench_memcpy(int n_threads) { + ggml_time_init(); + + size_t n = 50; + size_t arr = n_threads > 0 ? 1024 : n_threads; // trick to avoid compiler optimizations + + // 1 GB array + const size_t size = arr*1024llu*1024llu; + + char * src = (char *) malloc(size); + char * dst = (char *) malloc(size); + + for (size_t i = 0; i < size; i++) src[i] = i; + + memcpy(dst, src, size); // heat-up + + double tsum = 0.0; + + for (size_t i = 0; i < n; i++) { + const int64_t t0 = ggml_time_us(); + + memcpy(dst, src, size); + + const int64_t t1 = ggml_time_us(); + + tsum += (t1 - t0)*1e-6; + + src[0] = rand(); + } + + fprintf(stderr, "memcpy: %.2f GB/s\n", (double) (n*size)/(tsum*1024llu*1024llu*1024llu)); + + // needed to prevent the compile from optimizing the memcpy away + { + double sum = 0.0; + + for (size_t i = 0; i < size; i++) sum += dst[i]; + + fprintf(stderr, "sum: %s %f\n", sum == -536870910.00 ? "ok" : "error", sum); + } + + free(src); + free(dst); + + return 0; +} + +WHISPER_API int whisper_bench_ggml_mul_mat(int n_threads) { + ggml_time_init(); + + const int n_max = 128; + + const std::vector sizes = { + 64, 128, 256, 512, 1024, 2048, 4096, + }; + + const size_t N_max = sizes.back(); + + // a: N*N*sizeof(float) + // b: N*N*sizeof(float) + // c: N*N*sizeof(float) + // when F16 is used, there is an extra work buffer of size N*N*sizeof(float) + std::vector buf(4llu*N_max*N_max*sizeof(float) + 4*256); + + for (size_t i = 0; i < buf.size(); i++) buf[i] = i; + + for (int j = 0; j < (int) sizes.size(); j++) { + int n_fp16 = 0; + int n_fp32 = 0; + + // GFLOPS/s + double s_fp16 = 0.0; + double s_fp32 = 0.0; + + const size_t N = sizes[j]; + + for (int k = 0; k < 2; ++k) { + const ggml_type wtype = k == 0 ? GGML_TYPE_F16 : GGML_TYPE_F32; + + double & s = k == 0 ? s_fp16 : s_fp32; + int & n = k == 0 ? 
n_fp16 : n_fp32; + + struct ggml_init_params gparams = { + /*.mem_size =*/ buf.size(), + /*.mem_buffer =*/ buf.data(), + }; + + struct ggml_context * ctx0 = ggml_init(gparams); + + struct ggml_tensor * a = ggml_new_tensor_2d(ctx0, wtype, N, N); + struct ggml_tensor * b = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, N, N); + + struct ggml_tensor * c = ggml_mul_mat(ctx0, a, b); + + struct ggml_cgraph gf = ggml_build_forward(c); + + gf.n_threads = n_threads; + + double tsum = 0.0; + + // heat-up + ggml_graph_compute(ctx0, &gf); + + for (int i = 0; i < n_max; ++i) { + const int64_t t0 = ggml_time_us(); + + ggml_graph_compute(ctx0, &gf); + + const int64_t t1 = ggml_time_us(); + + tsum += (t1 - t0)*1e-6; + n++; + + if (tsum > 1.0 && n >= 3) { + break; + } + } + + ggml_free(ctx0); + + s = ((2.0*N*N*N*n)/tsum)*1e-9; + } + + fprintf(stderr, "ggml_mul_mat: %5zu x %5zu: F16 %8.1f GFLOPS (%3d runs) / F32 %8.1f GFLOPS (%3d runs)\n", + N, N, s_fp16, n_fp16, s_fp32, n_fp32); + } + + return 0; +} + +// ================================================================================================= + +// ================================================================================================= + +// +// Experimental stuff below +// +// Not sure if these should be part of the library at all, because the quality of the results is not +// guaranteed. Might get removed at some point unless a robust algorithm implementation is found +// + +// ================================================================================================= + +// +// token-level timestamps +// + +static int timestamp_to_sample(int64_t t, int n_samples) { + return std::max(0, std::min((int) n_samples - 1, (int) ((t*WHISPER_SAMPLE_RATE)/100))); +} + +static int64_t sample_to_timestamp(int i_sample) { + return (100ll*i_sample)/WHISPER_SAMPLE_RATE; +} + +// a cost-function / heuristic that is high for text that takes longer to pronounce +// obviously, can be improved +static float voice_length(const std::string & text) { + float res = 0.0f; + + for (char c : text) { + if (c == ' ') { + res += 0.01f; + } else if (c == ',') { + res += 2.00f; + } else if (c == '.') { + res += 3.00f; + } else if (c == '!') { + res += 3.00f; + } else if (c == '?') { + res += 3.00f; + } else if (c >= '0' && c <= '9') { + res += 3.00f; + } else { + res += 1.00f; + } + } + + return res; +} + +// average the fabs of the signal +static std::vector get_signal_energy(const float * signal, int n_samples, int n_samples_per_half_window) { + const int hw = n_samples_per_half_window; + + std::vector result(n_samples); + + for (int i = 0; i < n_samples; i++) { + float sum = 0; + for (int j = -hw; j <= hw; j++) { + if (i + j >= 0 && i + j < n_samples) { + sum += fabs(signal[i + j]); + } + } + result[i] = sum/(2*hw + 1); + } + + return result; +} + +static void whisper_exp_compute_token_level_timestamps( + struct whisper_context & ctx, + int i_segment, + float thold_pt, + float thold_ptsum) { + auto & segment = ctx.result_all[i_segment]; + auto & tokens = segment.tokens; + + const int n_samples = ctx.energy.size(); + + if (n_samples == 0) { + fprintf(stderr, "%s: no signal data available\n", __func__); + return; + } + + const int64_t t0 = segment.t0; + const int64_t t1 = segment.t1; + + const int n = tokens.size(); + + if (n == 0) { + return; + } + + if (n == 1) { + tokens[0].t0 = t0; + tokens[0].t1 = t1; + + return; + } + + auto & t_beg = ctx.t_beg; + auto & t_last = ctx.t_last; + auto & tid_last = ctx.tid_last; + + for (int j = 0; j < n; ++j) { + auto & token = 
tokens[j]; + + if (j == 0) { + if (token.id == whisper_token_beg(&ctx)) { + tokens[j ].t0 = t0; + tokens[j ].t1 = t0; + tokens[j + 1].t0 = t0; + + t_beg = t0; + t_last = t0; + tid_last = whisper_token_beg(&ctx); + } else { + tokens[j ].t0 = t_last; + } + } + + const int64_t tt = t_beg + 2*(token.tid - whisper_token_beg(&ctx)); + + tokens[j].id = token.id; + tokens[j].tid = token.tid; + tokens[j].p = token.p; + tokens[j].pt = token.pt; + tokens[j].ptsum = token.ptsum; + + tokens[j].vlen = voice_length(whisper_token_to_str(&ctx, token.id)); + + if (token.pt > thold_pt && token.ptsum > thold_ptsum && token.tid > tid_last && tt <= t1) { + if (j > 0) { + tokens[j - 1].t1 = tt; + } + tokens[j].t0 = tt; + tid_last = token.tid; + } + } + + tokens[n - 2].t1 = t1; + tokens[n - 1].t0 = t1; + tokens[n - 1].t1 = t1; + + t_last = t1; + + // find intervals of tokens with unknown timestamps + // fill the timestamps by proportionally splitting the interval based on the token voice lengths + { + int p0 = 0; + int p1 = 0; + + while (true) { + while (p1 < n && tokens[p1].t1 < 0) { + p1++; + } + + if (p1 >= n) { + p1--; + } + + //printf("p0=%d p1=%d t0=%lld t1=%lld\n", p0, p1, tokens[p0].t0, tokens[p1].t1); + + if (p1 > p0) { + double psum = 0.0; + for (int j = p0; j <= p1; j++) { + psum += tokens[j].vlen; + } + + //printf("analyzing %d - %d, psum = %f\n", p0, p1, psum); + + const double dt = tokens[p1].t1 - tokens[p0].t0; + + // split the time proportionally to the voice length + for (int j = p0 + 1; j <= p1; j++) { + const double ct = tokens[j - 1].t0 + dt*tokens[j - 1].vlen/psum; + + tokens[j - 1].t1 = ct; + tokens[j ].t0 = ct; + } + } + + p1++; + p0 = p1; + if (p1 >= n) { + break; + } + } + } + + // fix up (just in case) + for (int j = 0; j < n - 1; j++) { + if (tokens[j].t1 < 0) { + tokens[j + 1].t0 = tokens[j].t1; + } + + if (j > 0) { + if (tokens[j - 1].t1 > tokens[j].t0) { + tokens[j].t0 = tokens[j - 1].t1; + tokens[j].t1 = std::max(tokens[j].t0, tokens[j].t1); + } + } + } + + // VAD + // expand or contract tokens based on voice activity + { + const int hw = WHISPER_SAMPLE_RATE/8; + + for (int j = 0; j < n; j++) { + if (tokens[j].id >= whisper_token_eot(&ctx)) { + continue; + } + + int s0 = timestamp_to_sample(tokens[j].t0, n_samples); + int s1 = timestamp_to_sample(tokens[j].t1, n_samples); + + const int ss0 = std::max(s0 - hw, 0); + const int ss1 = std::min(s1 + hw, n_samples); + + const int ns = ss1 - ss0; + + float sum = 0.0f; + + for (int k = ss0; k < ss1; k++) { + sum += ctx.energy[k]; + } + + const float thold = 0.5*sum/ns; + + { + int k = s0; + if (ctx.energy[k] > thold && j > 0) { + while (k > 0 && ctx.energy[k] > thold) { + k--; + } + tokens[j].t0 = sample_to_timestamp(k); + if (tokens[j].t0 < tokens[j - 1].t1) { + tokens[j].t0 = tokens[j - 1].t1; + } else { + s0 = k; + } + } else { + while (ctx.energy[k] < thold && k < s1) { + k++; + } + s0 = k; + tokens[j].t0 = sample_to_timestamp(k); + } + } + + { + int k = s1; + if (ctx.energy[k] > thold) { + while (k < n_samples - 1 && ctx.energy[k] > thold) { + k++; + } + tokens[j].t1 = sample_to_timestamp(k); + if (j < ns - 1 && tokens[j].t1 > tokens[j + 1].t0) { + tokens[j].t1 = tokens[j + 1].t0; + } else { + s1 = k; + } + } else { + while (ctx.energy[k] < thold && k > s0) { + k--; + } + s1 = k; + tokens[j].t1 = sample_to_timestamp(k); + } + } + } + } + + // fixed token expand (optional) + //{ + // const int t_expand = 0; + + // for (int j = 0; j < n; j++) { + // if (j > 0) { + // tokens[j].t0 = std::max(0, (int) (tokens[j].t0 - t_expand)); + // } + 
// if (j < n - 1) { + // tokens[j].t1 = tokens[j].t1 + t_expand; + // } + // } + //} + + // debug info + //for (int j = 0; j < n; ++j) { + // const auto & token = tokens[j]; + // const auto tt = token.pt > thold_pt && token.ptsum > 0.01 ? whisper_token_to_str(&ctx, token.tid) : "[?]"; + // printf("%s: %10s %6.3f %6.3f %6.3f %6.3f %5d %5d '%s'\n", __func__, + // tt, token.p, token.pt, token.ptsum, token.vlen, (int) token.t0, (int) token.t1, whisper_token_to_str(&ctx, token.id)); + + // if (tokens[j].id >= whisper_token_eot(&ctx)) { + // continue; + // } + //} +} diff --git a/bindings/ruby/ext/whisper.h b/bindings/ruby/ext/whisper.h new file mode 100644 index 00000000000..7eece797c16 --- /dev/null +++ b/bindings/ruby/ext/whisper.h @@ -0,0 +1,379 @@ +#ifndef WHISPER_H +#define WHISPER_H + +#include <stddef.h> +#include <stdint.h> +#include <stdbool.h> + +#ifdef WHISPER_SHARED +# ifdef _WIN32 +# ifdef WHISPER_BUILD +# define WHISPER_API __declspec(dllexport) +# else +# define WHISPER_API __declspec(dllimport) +# endif +# else +# define WHISPER_API __attribute__ ((visibility ("default"))) +# endif +#else +# define WHISPER_API +#endif + +#define WHISPER_SAMPLE_RATE 16000 +#define WHISPER_N_FFT 400 +#define WHISPER_N_MEL 80 +#define WHISPER_HOP_LENGTH 160 +#define WHISPER_CHUNK_SIZE 30 + +#ifdef __cplusplus +extern "C" { +#endif + + // + // C interface + // + // The following interface is thread-safe as long as the same whisper_context is not used by multiple threads + // concurrently. + // + // Basic usage: + // + // #include "whisper.h" + // + // ... + // + // struct whisper_context * ctx = whisper_init_from_file("/path/to/ggml-base.en.bin"); + // + // struct whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY); + // + // if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) { + // fprintf(stderr, "failed to process audio\n"); + // return 7; + // } + // + // const int n_segments = whisper_full_n_segments(ctx); + // for (int i = 0; i < n_segments; ++i) { + // const char * text = whisper_full_get_segment_text(ctx, i); + // printf("%s", text); + // } + // + // whisper_free(ctx); + // + // ... + // + // This is a demonstration of the most straightforward usage of the library. + // "pcmf32" contains the RAW audio data in 32-bit floating point format. + // + // The interface also allows for more fine-grained control over the computation, but it requires a deeper + // understanding of how the model works. + // + + struct whisper_context; + + typedef int whisper_token; + + typedef struct whisper_token_data { + whisper_token id; // token id + whisper_token tid; // forced timestamp token id + + float p; // probability of the token + float plog; // log probability of the token + float pt; // probability of the timestamp token + float ptsum; // sum of probabilities of all timestamp tokens + + // token-level timestamp data + // do not use if you haven't computed token-level timestamps + int64_t t0; // start time of the token + int64_t t1; // end time of the token + + float vlen; // voice length of the token + } whisper_token_data; + + typedef struct whisper_model_loader { + void * context; + + size_t (*read)(void * ctx, void * output, size_t read_size); + bool (*eof)(void * ctx); + void (*close)(void * ctx); + } whisper_model_loader; + + // Various functions for loading a ggml whisper model. + // Allocate (almost) all memory needed for the model.
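+ // whisper_init_from_file() reads the model from a file on disk, whisper_init_from_buffer() from a memory buffer, + // and whisper_init() reads it through the read/eof/close callbacks of a user-provided whisper_model_loader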
+ // Return NULL on failure + WHISPER_API struct whisper_context * whisper_init_from_file(const char * path_model); + WHISPER_API struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_size); + WHISPER_API struct whisper_context * whisper_init(struct whisper_model_loader * loader); + + // Frees all memory allocated by the model. + WHISPER_API void whisper_free(struct whisper_context * ctx); + + // Convert RAW PCM audio to log mel spectrogram. + // The resulting spectrogram is stored inside the provided whisper context. + // Returns 0 on success + WHISPER_API int whisper_pcm_to_mel( + struct whisper_context * ctx, + const float * samples, + int n_samples, + int n_threads); + + // Convert RAW PCM audio to log mel spectrogram but applies a Phase Vocoder to speed up the audio x2. + // The resulting spectrogram is stored inside the provided whisper context. + // Returns 0 on success + WHISPER_API int whisper_pcm_to_mel_phase_vocoder( + struct whisper_context* ctx, + const float* samples, + int n_samples, + int n_threads); + + + // This can be used to set a custom log mel spectrogram inside the provided whisper context. + // Use this instead of whisper_pcm_to_mel() if you want to provide your own log mel spectrogram. + // n_mel must be 80 + // Returns 0 on success + WHISPER_API int whisper_set_mel( + struct whisper_context * ctx, + const float * data, + int n_len, + int n_mel); + + // Run the Whisper encoder on the log mel spectrogram stored inside the provided whisper context. + // Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first. + // offset can be used to specify the offset of the first frame in the spectrogram. + // Returns 0 on success + WHISPER_API int whisper_encode( + struct whisper_context * ctx, + int offset, + int n_threads); + + // Run the Whisper decoder to obtain the logits and probabilities for the next token. + // Make sure to call whisper_encode() first. + // tokens + n_tokens is the provided context for the decoder. + // n_past is the number of tokens to use from previous decoder calls. + // Returns 0 on success + // TODO: add support for multiple decoders + WHISPER_API int whisper_decode( + struct whisper_context * ctx, + const whisper_token * tokens, + int n_tokens, + int n_past, + int n_threads); + + // Convert the provided text into tokens. + // The tokens pointer must be large enough to hold the resulting tokens. + // Returns the number of tokens on success, no more than n_max_tokens + // Returns -1 on failure + // TODO: not sure if correct + WHISPER_API int whisper_tokenize( + struct whisper_context * ctx, + const char * text, + whisper_token * tokens, + int n_max_tokens); + + // Largest language id (i.e. number of available languages - 1) + WHISPER_API int whisper_lang_max_id(); + + // Return the id of the specified language, returns -1 if not found + // Examples: + // "de" -> 2 + // "german" -> 2 + WHISPER_API int whisper_lang_id(const char * lang); + + // Return the short string of the specified language id (e.g. 
2 -> "de"), returns nullptr if not found + WHISPER_API const char * whisper_lang_str(int id); + + // Use mel data at offset_ms to try and auto-detect the spoken language + // Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first + // Returns the top language id or negative on failure + // If not null, fills the lang_probs array with the probabilities of all languages + // The array must be whisper_lang_max_id() + 1 in size + // ref: https://github.com/openai/whisper/blob/main/whisper/decoding.py#L18-L69 + WHISPER_API int whisper_lang_auto_detect( + struct whisper_context * ctx, + int offset_ms, + int n_threads, + float * lang_probs); + + WHISPER_API int whisper_n_len (struct whisper_context * ctx); // mel length + WHISPER_API int whisper_n_vocab (struct whisper_context * ctx); + WHISPER_API int whisper_n_text_ctx (struct whisper_context * ctx); + WHISPER_API int whisper_n_audio_ctx (struct whisper_context * ctx); + WHISPER_API int whisper_is_multilingual(struct whisper_context * ctx); + + // Token logits obtained from the last call to whisper_decode() + // The logits for the last token are stored in the last row + // Rows: n_tokens + // Cols: n_vocab + WHISPER_API float * whisper_get_logits(struct whisper_context * ctx); + + // Token Id -> String. Uses the vocabulary in the provided context + WHISPER_API const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token); + + // Special tokens + WHISPER_API whisper_token whisper_token_eot (struct whisper_context * ctx); + WHISPER_API whisper_token whisper_token_sot (struct whisper_context * ctx); + WHISPER_API whisper_token whisper_token_prev(struct whisper_context * ctx); + WHISPER_API whisper_token whisper_token_solm(struct whisper_context * ctx); + WHISPER_API whisper_token whisper_token_not (struct whisper_context * ctx); + WHISPER_API whisper_token whisper_token_beg (struct whisper_context * ctx); + WHISPER_API whisper_token whisper_token_lang(struct whisper_context * ctx, int lang_id); + + // Task tokens + WHISPER_API whisper_token whisper_token_translate (void); + WHISPER_API whisper_token whisper_token_transcribe(void); + + // Performance information + WHISPER_API void whisper_print_timings(struct whisper_context * ctx); + WHISPER_API void whisper_reset_timings(struct whisper_context * ctx); + + // Print system information + WHISPER_API const char * whisper_print_system_info(void); + + //////////////////////////////////////////////////////////////////////////// + + // Available sampling strategies + enum whisper_sampling_strategy { + WHISPER_SAMPLING_GREEDY, // similar to OpenAI's GreedyDecoder + WHISPER_SAMPLING_BEAM_SEARCH, // similar to OpenAI's BeamSearchDecoder + }; + + // Text segment callback + // Called on every newly generated text segment + // Use the whisper_full_...() functions to obtain the text segments + typedef void (*whisper_new_segment_callback)(struct whisper_context * ctx, int n_new, void * user_data); + + // Encoder begin callback + // If not NULL, called before the encoder starts + // If it returns false, the computation is aborted + typedef bool (*whisper_encoder_begin_callback)(struct whisper_context * ctx, void * user_data); + + // Parameters for the whisper_full() function + // If you change the order or add new parameters, make sure to update the default values in whisper.cpp: + // whisper_full_default_params() + struct whisper_full_params { + enum whisper_sampling_strategy strategy; + + int n_threads; + int n_max_text_ctx; // max tokens to use from past text as
prompt for the decoder + int offset_ms; // start offset in ms + int duration_ms; // audio duration to process in ms + + bool translate; + bool no_context; // do not use past transcription (if any) as initial prompt for the decoder + bool single_segment; // force single segment output (useful for streaming) + bool print_special; // print special tokens (e.g. <SOT>, <EOT>, <BEG>, etc.) + bool print_progress; // print progress information + bool print_realtime; // print results from within whisper.cpp (avoid it, use callback instead) + bool print_timestamps; // print timestamps for each text segment when printing realtime + + // [EXPERIMENTAL] token-level timestamps + bool token_timestamps; // enable token-level timestamps + float thold_pt; // timestamp token probability threshold (~0.01) + float thold_ptsum; // timestamp token sum probability threshold (~0.01) + int max_len; // max segment length in characters + bool split_on_word; // split on word rather than on token (when used with max_len) + int max_tokens; // max tokens per segment (0 = no limit) + + // [EXPERIMENTAL] speed-up techniques + // note: these can significantly reduce the quality of the output + bool speed_up; // speed-up the audio by 2x using Phase Vocoder + int audio_ctx; // overwrite the audio context size (0 = use default) + + // tokens to provide to the whisper decoder as initial prompt + // these are prepended to any existing text context from a previous call + const whisper_token * prompt_tokens; + int prompt_n_tokens; + + // for auto-detection, set to nullptr, "" or "auto" + const char * language; + + // common decoding parameters: + bool suppress_blank; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L89 + bool suppress_non_speech_tokens; // ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253 + + float temperature; // initial decoding temperature, ref: https://ai.stackexchange.com/a/32478 + float max_initial_ts; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L97 + float length_penalty; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L267 + + // fallback parameters + // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L274-L278 + float temperature_inc; + float entropy_thold; // similar to OpenAI's "compression_ratio_threshold" + float logprob_thold; + float no_speech_thold; // TODO: not implemented + + struct { + int best_of; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L264 + } greedy; + + struct { + int beam_size; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L265 + + float patience; // TODO: not implemented, ref: https://arxiv.org/pdf/2204.05424.pdf + } beam_search; + + // called for every newly generated text segment + whisper_new_segment_callback new_segment_callback; + void * new_segment_callback_user_data; + + // called each time before the encoder starts + whisper_encoder_begin_callback encoder_begin_callback; + void * encoder_begin_callback_user_data; + }; + + WHISPER_API struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy); + + // Run the entire model: PCM -> log mel spectrogram
-> encoder -> decoder -> text + // Uses the specified decoding strategy to obtain the text. + WHISPER_API int whisper_full( + struct whisper_context * ctx, + struct whisper_full_params params, + const float * samples, + int n_samples); + + // Split the input audio in chunks and process each chunk separately using whisper_full() + // It seems this approach can offer some speedup in some cases. + // However, the transcription accuracy can be worse at the beginning and end of each chunk. + WHISPER_API int whisper_full_parallel( + struct whisper_context * ctx, + struct whisper_full_params params, + const float * samples, + int n_samples, + int n_processors); + + // Number of generated text segments. + // A segment can be a few words, a sentence, or even a paragraph. + WHISPER_API int whisper_full_n_segments(struct whisper_context * ctx); + + // Language id associated with the current context + WHISPER_API int whisper_full_lang_id(struct whisper_context * ctx); + + // Get the start and end time of the specified segment. + WHISPER_API int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment); + WHISPER_API int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment); + + // Get the text of the specified segment. + WHISPER_API const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment); + + // Get number of tokens in the specified segment. + WHISPER_API int whisper_full_n_tokens(struct whisper_context * ctx, int i_segment); + + // Get the token text of the specified token in the specified segment. + WHISPER_API const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token); + WHISPER_API whisper_token whisper_full_get_token_id (struct whisper_context * ctx, int i_segment, int i_token); + + // Get token data for the specified token in the specified segment. + // This contains probabilities, timestamps, etc. + WHISPER_API whisper_token_data whisper_full_get_token_data(struct whisper_context * ctx, int i_segment, int i_token); + + // Get the probability of the specified token in the specified segment. 
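+ // (this is the same value as whisper_full_get_token_data(ctx, i_segment, i_token).p)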
+ WHISPER_API float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token); + + //////////////////////////////////////////////////////////////////////////// + + // Temporary helpers needed for exposing ggml interface + + WHISPER_API int whisper_bench_memcpy(int n_threads); + WHISPER_API int whisper_bench_ggml_mul_mat(int n_threads); + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/bindings/ruby/tests/test_whisper.rb b/bindings/ruby/tests/test_whisper.rb new file mode 100644 index 00000000000..82c21fe71f3 --- /dev/null +++ b/bindings/ruby/tests/test_whisper.rb @@ -0,0 +1,23 @@ +TOPDIR = File.expand_path(File.join(File.dirname(__FILE__), '..')) +EXTDIR = File.join(TOPDIR, 'ext') +#$LIBDIR = File.join(TOPDIR, 'lib') +#$:.unshift(LIBDIR) +$:.unshift(EXTDIR) + +require 'whisper' +require 'test/unit' + +class TestWhisper < Test::Unit::TestCase + def setup + @params = Whisper::Params.new + @whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'for-tests-ggml-base.en.bin')) + end + + def test_autodetect + @params.auto_detection = true + end + + def test_whisper + end + +end From 66482e326d71269142eda88ed3377ba33196c497 Mon Sep 17 00:00:00 2001 From: Todd Fisher Date: Tue, 14 Feb 2023 08:00:51 -0500 Subject: [PATCH 2/8] avoid adding these they are copied in via extconf.rb --- bindings/ruby/ext/ggml.c | 8616 --------------------------------- bindings/ruby/ext/ggml.h | 748 --- bindings/ruby/ext/whisper.cpp | 4814 ------------------ bindings/ruby/ext/whisper.h | 379 -- 4 files changed, 14557 deletions(-) delete mode 100644 bindings/ruby/ext/ggml.c delete mode 100644 bindings/ruby/ext/ggml.h delete mode 100644 bindings/ruby/ext/whisper.cpp delete mode 100644 bindings/ruby/ext/whisper.h diff --git a/bindings/ruby/ext/ggml.c b/bindings/ruby/ext/ggml.c deleted file mode 100644 index d67612c36a3..00000000000 --- a/bindings/ruby/ext/ggml.c +++ /dev/null @@ -1,8616 +0,0 @@ -#include "ggml.h" - -#if defined(_MSC_VER) || defined(__MINGW32__) -#include // using malloc.h with MSC/MINGW -#elif !defined(__FreeBSD__) -#include -#endif - -#include -#include -#include -#include -#include -#include -#include - -// if C99 - static_assert is noop -// ref: https://stackoverflow.com/a/53923785/4039976 -#ifndef static_assert -#define static_assert(cond, msg) struct global_scope_noop_trick -#endif - -#if defined _MSC_VER || defined(__MINGW32__) - -#if !defined(__MINGW32__) -#include -#else -// ref: https://github.com/ggerganov/whisper.cpp/issues/168 -#include -#include -#endif - -typedef volatile LONG atomic_int; -typedef atomic_int atomic_bool; - -static void atomic_store(atomic_int* ptr, LONG val) { - InterlockedExchange(ptr, val); -} -static LONG atomic_load(atomic_int* ptr) { - return InterlockedCompareExchange(ptr, 0, 0); -} -static LONG atomic_fetch_add(atomic_int* ptr, LONG inc) { - return InterlockedExchangeAdd(ptr, inc); -} -static LONG atomic_fetch_sub(atomic_int* ptr, LONG dec) { - return atomic_fetch_add(ptr, -(dec)); -} - -typedef HANDLE pthread_t; - -typedef DWORD thread_ret_t; -static int pthread_create(pthread_t* out, void* unused, thread_ret_t(*func)(void*), void* arg) { - HANDLE handle = CreateThread(NULL, 0, (LPTHREAD_START_ROUTINE) func, arg, 0, NULL); - if (handle == NULL) - { - return EAGAIN; - } - - *out = handle; - return 0; -} - -static int pthread_join(pthread_t thread, void* unused) { - return (int) WaitForSingleObject(thread, INFINITE); -} - -static int sched_yield (void) { - Sleep (0); - return 0; -} -#else -#include -#include - 
-typedef void* thread_ret_t; -#endif - -#ifdef __HAIKU__ -#define static_assert(cond, msg) _Static_assert(cond, msg) -#endif - -/*#define GGML_PERF*/ -#define GGML_DEBUG 0 -#define GGML_GELU_FP16 - -#define GGML_SOFT_MAX_UNROLL 4 -#define GGML_VEC_DOT_UNROLL 2 - -#ifdef GGML_USE_ACCELERATE -// uncomment to use vDSP for soft max computation -// note: not sure if it is actually faster -//#define GGML_SOFT_MAX_ACCELERATE -#endif - -#if UINTPTR_MAX == 0xFFFFFFFF - #define GGML_MEM_ALIGN 4 -#else - #define GGML_MEM_ALIGN 16 -#endif - -#define UNUSED(x) (void)(x) -#define SWAP(x, y, T) do { T SWAP = x; x = y; y = SWAP; } while (0) - -#define GGML_ASSERT(x) \ - do { \ - if (!(x)) { \ - fprintf(stderr, "GGML_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \ - abort(); \ - } \ - } while (0) - -#ifdef GGML_USE_ACCELERATE -#include -#elif GGML_USE_OPENBLAS -#include -#endif - -#undef MIN -#undef MAX -#define MIN(a, b) ((a) < (b) ? (a) : (b)) -#define MAX(a, b) ((a) > (b) ? (a) : (b)) - -// floating point type used to accumulate sums -typedef double ggml_float; - -// 16-bit float -// on Arm, we use __fp16 -// on x86, we use uint16_t -#ifdef __ARM_NEON - -// if YCM cannot find , make a symbolic link to it, for example: -// -// $ ln -sfn /Library/Developer/CommandLineTools/usr/lib/clang/13.1.6/include/arm_neon.h ./src/ -// -#include - -#define GGML_COMPUTE_FP16_TO_FP32(x) (x) -#define GGML_COMPUTE_FP32_TO_FP16(x) (x) - -#define GGML_FP16_TO_FP32(x) (x) -#define GGML_FP32_TO_FP16(x) (x) - -#else - -#ifdef __wasm_simd128__ -#include -#else -#ifdef __POWER9_VECTOR__ -#include -#undef bool -#define bool _Bool -#else -#include -#endif -#endif - -#ifdef __F16C__ - -#define GGML_COMPUTE_FP16_TO_FP32(x) _cvtsh_ss(x) -#define GGML_COMPUTE_FP32_TO_FP16(x) _cvtss_sh(x, 0) - -#else - -// FP16 <-> FP32 -// ref: https://github.com/Maratyszcza/FP16 - -static inline float fp32_from_bits(uint32_t w) { - union { - uint32_t as_bits; - float as_value; - } fp32; - fp32.as_bits = w; - return fp32.as_value; -} - -static inline uint32_t fp32_to_bits(float f) { - union { - float as_value; - uint32_t as_bits; - } fp32; - fp32.as_value = f; - return fp32.as_bits; -} - -static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) { - const uint32_t w = (uint32_t) h << 16; - const uint32_t sign = w & UINT32_C(0x80000000); - const uint32_t two_w = w + w; - - const uint32_t exp_offset = UINT32_C(0xE0) << 23; -#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__) - const float exp_scale = 0x1.0p-112f; -#else - const float exp_scale = fp32_from_bits(UINT32_C(0x7800000)); -#endif - const float normalized_value = fp32_from_bits((two_w >> 4) + exp_offset) * exp_scale; - - const uint32_t magic_mask = UINT32_C(126) << 23; - const float magic_bias = 0.5f; - const float denormalized_value = fp32_from_bits((two_w >> 17) | magic_mask) - magic_bias; - - const uint32_t denormalized_cutoff = UINT32_C(1) << 27; - const uint32_t result = sign | - (two_w < denormalized_cutoff ? 
fp32_to_bits(denormalized_value) : fp32_to_bits(normalized_value)); - return fp32_from_bits(result); -} - -static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) { -#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__) - const float scale_to_inf = 0x1.0p+112f; - const float scale_to_zero = 0x1.0p-110f; -#else - const float scale_to_inf = fp32_from_bits(UINT32_C(0x77800000)); - const float scale_to_zero = fp32_from_bits(UINT32_C(0x08800000)); -#endif - float base = (fabsf(f) * scale_to_inf) * scale_to_zero; - - const uint32_t w = fp32_to_bits(f); - const uint32_t shl1_w = w + w; - const uint32_t sign = w & UINT32_C(0x80000000); - uint32_t bias = shl1_w & UINT32_C(0xFF000000); - if (bias < UINT32_C(0x71000000)) { - bias = UINT32_C(0x71000000); - } - - base = fp32_from_bits((bias >> 1) + UINT32_C(0x07800000)) + base; - const uint32_t bits = fp32_to_bits(base); - const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00); - const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF); - const uint32_t nonsign = exp_bits + mantissa_bits; - return (sign >> 16) | (shl1_w > UINT32_C(0xFF000000) ? UINT16_C(0x7E00) : nonsign); -} - -#define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x) -#define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x) - -#endif // __F16C__ - -#endif // __ARM_NEON - -// -// global data -// - -// precomputed gelu table for f16 (128 KB) -static ggml_fp16_t table_gelu_f16[1 << 16]; - -// precomputed exp table for f16 (128 KB) -static ggml_fp16_t table_exp_f16[1 << 16]; - -// precomputed f32 table for f16 (256 KB) -static float table_f32_f16[1 << 16]; - -// On ARM NEON, it's quicker to directly convert x -> x instead of calling into ggml_lookup_fp16_to_fp32, -// so we define GGML_FP16_TO_FP32 and GGML_FP32_TO_FP16 elsewhere for NEON. 
-#if !defined(GGML_FP16_TO_FP32) || !defined(GGML_FP32_TO_FP16) - -inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) { - uint16_t s; - memcpy(&s, &f, sizeof(uint16_t)); - return table_f32_f16[s]; -} - -#define GGML_FP16_TO_FP32(x) ggml_lookup_fp16_to_fp32(x) -#define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x) - -#endif - -// note: do not use these inside ggml.c -// these are meant to be used via the ggml.h API -float ggml_fp16_to_fp32(ggml_fp16_t x) { - return GGML_FP16_TO_FP32(x); -} - -ggml_fp16_t ggml_fp32_to_fp16(float x) { - return GGML_FP32_TO_FP16(x); -} - -// -// timing -// - -#if defined(_MSC_VER) || defined(__MINGW32__) -static int64_t timer_freq; -void ggml_time_init(void) { - LARGE_INTEGER frequency; - QueryPerformanceFrequency(&frequency); - timer_freq = frequency.QuadPart; -} -int64_t ggml_time_ms(void) { - LARGE_INTEGER t; - QueryPerformanceCounter(&t); - return (t.QuadPart * 1000) / timer_freq; -} -int64_t ggml_time_us(void) { - LARGE_INTEGER t; - QueryPerformanceCounter(&t); - return (t.QuadPart * 1000000) / timer_freq; -} -#else -void ggml_time_init(void) {} -int64_t ggml_time_ms(void) { - struct timespec ts; - clock_gettime(CLOCK_MONOTONIC, &ts); - return (int64_t)ts.tv_sec*1000 + (int64_t)ts.tv_nsec/1000000; -} - -int64_t ggml_time_us(void) { - struct timespec ts; - clock_gettime(CLOCK_MONOTONIC, &ts); - return (int64_t)ts.tv_sec*1000000 + (int64_t)ts.tv_nsec/1000; -} -#endif - -int64_t ggml_cycles(void) { - return clock(); -} - -int64_t ggml_cycles_per_ms(void) { - return CLOCKS_PER_SEC/1000; -} - -#ifdef GGML_PERF -#define ggml_perf_time_ms() ggml_time_ms() -#define ggml_perf_time_us() ggml_time_us() -#define ggml_perf_cycles() ggml_cycles() -#define ggml_perf_cycles_per_ms() ggml_cycles_per_ms() -#else -#define ggml_perf_time_ms() 0 -#define ggml_perf_time_us() 0 -#define ggml_perf_cycles() 0 -#define ggml_perf_cycles_per_ms() 0 -#endif - -// -// cache line -// - -#if defined(__cpp_lib_hardware_interference_size) -#define CACHE_LINE_SIZE hardware_destructive_interference_size -#else -#if defined(__POWER9_VECTOR__) -#define CACHE_LINE_SIZE 128 -#else -#define CACHE_LINE_SIZE 64 -#endif -#endif - -static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float); - -// -// simd mappings -// - -// we define a common set of C macros which map to specific intrinsics based on the current architecture -// we then implement the fundamental computation operations below using only these macros -// adding support for new architectures requires to define the corresponding SIMD macros -// -// GGML_F32_STEP / GGML_F16_STEP -// number of elements to process in a single step -// -// GGML_F32_EPR / GGML_F16_EPR -// number of elements to fit in a single register -// - -#if defined(__ARM_NEON) && defined(__ARM_FEATURE_FMA) - -#define GGML_SIMD - -// F32 NEON - -#define GGML_F32_STEP 16 -#define GGML_F32_EPR 4 - -#define GGML_F32x4 float32x4_t -#define GGML_F32x4_ZERO vdupq_n_f32(0.0f) -#define GGML_F32x4_SET1(x) vdupq_n_f32(x) -#define GGML_F32x4_LOAD vld1q_f32 -#define GGML_F32x4_STORE vst1q_f32 -#define GGML_F32x4_FMA(a, b, c) vfmaq_f32(a, b, c) -#define GGML_F32x4_ADD vaddq_f32 -#define GGML_F32x4_MUL vmulq_f32 -#if defined(__ARM_FEATURE_QRDMX) - #define GGML_F32x4_REDUCE_ONE(x) vaddvq_f32(x) -#else - #define GGML_F32x4_REDUCE_ONE(x) \ - (vgetq_lane_f32(x, 0) + \ - vgetq_lane_f32(x, 1) + \ - vgetq_lane_f32(x, 2) + \ - vgetq_lane_f32(x, 3)) -#endif -#define GGML_F32x4_REDUCE(res, x) \ -{ \ - for (int i = 0; i < GGML_F32_ARR/2; ++i) { \ - x[2*i] = 
vaddq_f32(x[2*i], x[2*i+1]); \ - } \ - for (int i = 0; i < GGML_F32_ARR/4; ++i) { \ - x[4*i] = vaddq_f32(x[4*i], x[4*i+2]); \ - } \ - for (int i = 0; i < GGML_F32_ARR/8; ++i) { \ - x[8*i] = vaddq_f32(x[8*i], x[8*i+4]); \ - } \ - res = GGML_F32x4_REDUCE_ONE(x[0]); \ -} - -#define GGML_F32_VEC GGML_F32x4 -#define GGML_F32_VEC_ZERO GGML_F32x4_ZERO -#define GGML_F32_VEC_SET1 GGML_F32x4_SET1 -#define GGML_F32_VEC_LOAD GGML_F32x4_LOAD -#define GGML_F32_VEC_STORE GGML_F32x4_STORE -#define GGML_F32_VEC_FMA GGML_F32x4_FMA -#define GGML_F32_VEC_ADD GGML_F32x4_ADD -#define GGML_F32_VEC_MUL GGML_F32x4_MUL -#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE - -// F16 NEON - -#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) - #define GGML_F16_STEP 32 - #define GGML_F16_EPR 8 - - #define GGML_F16x8 float16x8_t - #define GGML_F16x8_ZERO vdupq_n_f16(0.0f) - #define GGML_F16x8_SET1(x) vdupq_n_f16(x) - #define GGML_F16x8_LOAD vld1q_f16 - #define GGML_F16x8_STORE vst1q_f16 - #define GGML_F16x8_FMA(a, b, c) vfmaq_f16(a, b, c) - #define GGML_F16x8_ADD vaddq_f16 - #define GGML_F16x8_MUL vmulq_f16 - #define GGML_F16x8_REDUCE(res, x) \ - { \ - for (int i = 0; i < GGML_F16_ARR/2; ++i) { \ - x[2*i] = vaddq_f16(x[2*i], x[2*i+1]); \ - } \ - for (int i = 0; i < GGML_F16_ARR/4; ++i) { \ - x[4*i] = vaddq_f16(x[4*i], x[4*i+2]); \ - } \ - for (int i = 0; i < GGML_F16_ARR/8; ++i) { \ - x[8*i] = vaddq_f16(x[8*i], x[8*i+4]); \ - } \ - const float32x4_t t0 = vcvt_f32_f16(vget_low_f16 (x[0])); \ - const float32x4_t t1 = vcvt_f32_f16(vget_high_f16(x[0])); \ - res = vaddvq_f32(vaddq_f32(t0, t1)); \ - } - - #define GGML_F16_VEC GGML_F16x8 - #define GGML_F16_VEC_ZERO GGML_F16x8_ZERO - #define GGML_F16_VEC_SET1 GGML_F16x8_SET1 - #define GGML_F16_VEC_LOAD(p, i) GGML_F16x8_LOAD(p) - #define GGML_F16_VEC_STORE(p, r, i) GGML_F16x8_STORE(p, r[i]) - #define GGML_F16_VEC_FMA GGML_F16x8_FMA - #define GGML_F16_VEC_ADD GGML_F16x8_ADD - #define GGML_F16_VEC_MUL GGML_F16x8_MUL - #define GGML_F16_VEC_REDUCE GGML_F16x8_REDUCE -#else - // if FP16 vector arithmetic is not supported, we use FP32 instead - // and take advantage of the vcvt_ functions to convert to/from FP16 - - #define GGML_F16_STEP 16 - #define GGML_F16_EPR 4 - - #define GGML_F32Cx4 float32x4_t - #define GGML_F32Cx4_ZERO vdupq_n_f32(0.0f) - #define GGML_F32Cx4_SET1(x) vdupq_n_f32(x) - #define GGML_F32Cx4_LOAD(x) vcvt_f32_f16(vld1_f16(x)) - #define GGML_F32Cx4_STORE(x, y) vst1_f16(x, vcvt_f16_f32(y)) - #define GGML_F32Cx4_FMA(a, b, c) vfmaq_f32(a, b, c) - #define GGML_F32Cx4_ADD vaddq_f32 - #define GGML_F32Cx4_MUL vmulq_f32 - #define GGML_F32Cx4_REDUCE GGML_F32x4_REDUCE - - #define GGML_F16_VEC GGML_F32Cx4 - #define GGML_F16_VEC_ZERO GGML_F32Cx4_ZERO - #define GGML_F16_VEC_SET1 GGML_F32Cx4_SET1 - #define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx4_LOAD(p) - #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE(p, r[i]) - #define GGML_F16_VEC_FMA GGML_F32Cx4_FMA - #define GGML_F16_VEC_ADD GGML_F32Cx4_ADD - #define GGML_F16_VEC_MUL GGML_F32Cx4_MUL - #define GGML_F16_VEC_REDUCE GGML_F32Cx4_REDUCE -#endif - -#elif defined(__AVX__) - -#define GGML_SIMD - -// F32 AVX - -#define GGML_F32_STEP 32 -#define GGML_F32_EPR 8 - -#define GGML_F32x8 __m256 -#define GGML_F32x8_ZERO _mm256_setzero_ps() -#define GGML_F32x8_SET1(x) _mm256_set1_ps(x) -#define GGML_F32x8_LOAD _mm256_loadu_ps -#define GGML_F32x8_STORE _mm256_storeu_ps -#if defined(__FMA__) - #define GGML_F32x8_FMA(a, b, c) _mm256_fmadd_ps(b, c, a) -#else - #define GGML_F32x8_FMA(a, b, c) _mm256_add_ps(_mm256_mul_ps(b, c), a) -#endif -#define 
GGML_F32x8_ADD _mm256_add_ps -#define GGML_F32x8_MUL _mm256_mul_ps -#define GGML_F32x8_REDUCE(res, x) \ -{ \ - for (int i = 0; i < GGML_F32_ARR/2; ++i) { \ - x[2*i] = _mm256_add_ps(x[2*i], x[2*i+1]); \ - } \ - for (int i = 0; i < GGML_F32_ARR/4; ++i) { \ - x[4*i] = _mm256_add_ps(x[4*i], x[4*i+2]); \ - } \ - for (int i = 0; i < GGML_F32_ARR/8; ++i) { \ - x[8*i] = _mm256_add_ps(x[8*i], x[8*i+4]); \ - } \ - const __m128 t0 = _mm_add_ps(_mm256_castps256_ps128(x[0]), \ - _mm256_extractf128_ps(x[0], 1)); \ - const __m128 t1 = _mm_hadd_ps(t0, t0); \ - res = _mm_cvtss_f32(_mm_hadd_ps(t1, t1)); \ -} -// TODO: is this optimal ? - -#define GGML_F32_VEC GGML_F32x8 -#define GGML_F32_VEC_ZERO GGML_F32x8_ZERO -#define GGML_F32_VEC_SET1 GGML_F32x8_SET1 -#define GGML_F32_VEC_LOAD GGML_F32x8_LOAD -#define GGML_F32_VEC_STORE GGML_F32x8_STORE -#define GGML_F32_VEC_FMA GGML_F32x8_FMA -#define GGML_F32_VEC_ADD GGML_F32x8_ADD -#define GGML_F32_VEC_MUL GGML_F32x8_MUL -#define GGML_F32_VEC_REDUCE GGML_F32x8_REDUCE - -// F16 AVX - -#define GGML_F16_STEP 32 -#define GGML_F16_EPR 8 - -// F16 arithmetic is not supported by AVX, so we use F32 instead -// we take advantage of the _mm256_cvt intrinsics to convert F16 <-> F32 - -#define GGML_F32Cx8 __m256 -#define GGML_F32Cx8_ZERO _mm256_setzero_ps() -#define GGML_F32Cx8_SET1(x) _mm256_set1_ps(x) -#define GGML_F32Cx8_LOAD(x) _mm256_cvtph_ps(_mm_loadu_si128((__m128i *)(x))) -#define GGML_F32Cx8_STORE(x, y) _mm_storeu_si128((__m128i *)(x), _mm256_cvtps_ph(y, 0)) -#define GGML_F32Cx8_FMA GGML_F32x8_FMA -#define GGML_F32Cx8_ADD _mm256_add_ps -#define GGML_F32Cx8_MUL _mm256_mul_ps -#define GGML_F32Cx8_REDUCE GGML_F32x8_REDUCE - -#define GGML_F16_VEC GGML_F32Cx8 -#define GGML_F16_VEC_ZERO GGML_F32Cx8_ZERO -#define GGML_F16_VEC_SET1 GGML_F32Cx8_SET1 -#define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx8_LOAD(p) -#define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx8_STORE(p, r[i]) -#define GGML_F16_VEC_FMA GGML_F32Cx8_FMA -#define GGML_F16_VEC_ADD GGML_F32Cx8_ADD -#define GGML_F16_VEC_MUL GGML_F32Cx8_MUL -#define GGML_F16_VEC_REDUCE GGML_F32Cx8_REDUCE - -#elif defined(__POWER9_VECTOR__) - -#define GGML_SIMD - -// F32 POWER9 - -#define GGML_F32_STEP 32 -#define GGML_F32_EPR 4 - -#define GGML_F32x4 vector float -#define GGML_F32x4_ZERO 0.0f -#define GGML_F32x4_SET1 vec_splats -#define GGML_F32x4_LOAD(p) vec_xl(0, p) -#define GGML_F32x4_STORE(p, r) vec_xst(r, 0, p) -#define GGML_F32x4_FMA(a, b, c) vec_madd(b, c, a) -#define GGML_F32x4_ADD vec_add -#define GGML_F32x4_MUL vec_mul -#define GGML_F32x4_REDUCE(res, x) \ -{ \ - for (int i = 0; i < GGML_F32_ARR/2; ++i) { \ - x[2*i] = vec_add(x[2*i], x[2*i+1]); \ - } \ - for (int i = 0; i < GGML_F32_ARR/4; ++i) { \ - x[4*i] = vec_add(x[4*i], x[4*i+2]); \ - } \ - for (int i = 0; i < GGML_F32_ARR/8; ++i) { \ - x[8*i] = vec_add(x[8*i], x[8*i+4]); \ - } \ - res = vec_extract(x[0], 0) + \ - vec_extract(x[0], 1) + \ - vec_extract(x[0], 2) + \ - vec_extract(x[0], 3); \ -} - -#define GGML_F32_VEC GGML_F32x4 -#define GGML_F32_VEC_ZERO GGML_F32x4_ZERO -#define GGML_F32_VEC_SET1 GGML_F32x4_SET1 -#define GGML_F32_VEC_LOAD GGML_F32x4_LOAD -#define GGML_F32_VEC_STORE GGML_F32x4_STORE -#define GGML_F32_VEC_FMA GGML_F32x4_FMA -#define GGML_F32_VEC_ADD GGML_F32x4_ADD -#define GGML_F32_VEC_MUL GGML_F32x4_MUL -#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE - -// F16 POWER9 -#define GGML_F16_STEP GGML_F32_STEP -#define GGML_F16_EPR GGML_F32_EPR -#define GGML_F16_VEC GGML_F32x4 -#define GGML_F16_VEC_ZERO GGML_F32x4_ZERO -#define GGML_F16_VEC_SET1 GGML_F32x4_SET1 -#define 
GGML_F16_VEC_FMA GGML_F32x4_FMA -#define GGML_F16_VEC_REDUCE GGML_F32x4_REDUCE -// Use vec_xl, not vec_ld, in case the load address is not aligned. -#define GGML_F16_VEC_LOAD(p, i) (i & 0x1) ? \ - vec_extract_fp32_from_shorth(vec_xl(0, p - GGML_F16_EPR)) : \ - vec_extract_fp32_from_shortl(vec_xl(0, p)) -#define GGML_ENDIAN_BYTE(i) ((unsigned char *)&(uint16_t){1})[i] -#define GGML_F16_VEC_STORE(p, r, i) \ - if (i & 0x1) \ - vec_xst(vec_pack_to_short_fp32(r[i - GGML_ENDIAN_BYTE(1)], \ - r[i - GGML_ENDIAN_BYTE(0)]), \ - 0, p - GGML_F16_EPR) - -#elif defined(__wasm_simd128__) - -#define GGML_SIMD - -// F32 WASM - -#define GGML_F32_STEP 16 -#define GGML_F32_EPR 4 - -#define GGML_F32x4 v128_t -#define GGML_F32x4_ZERO wasm_f32x4_splat(0.0f) -#define GGML_F32x4_SET1(x) wasm_f32x4_splat(x) -#define GGML_F32x4_LOAD wasm_v128_load -#define GGML_F32x4_STORE wasm_v128_store -#define GGML_F32x4_FMA(a, b, c) wasm_f32x4_add(wasm_f32x4_mul(b, c), a) -#define GGML_F32x4_ADD wasm_f32x4_add -#define GGML_F32x4_MUL wasm_f32x4_mul -#define GGML_F32x4_REDUCE(res, x) \ -{ \ - for (int i = 0; i < GGML_F32_ARR/2; ++i) { \ - x[2*i] = wasm_f32x4_add(x[2*i], x[2*i+1]); \ - } \ - for (int i = 0; i < GGML_F32_ARR/4; ++i) { \ - x[4*i] = wasm_f32x4_add(x[4*i], x[4*i+2]); \ - } \ - for (int i = 0; i < GGML_F32_ARR/8; ++i) { \ - x[8*i] = wasm_f32x4_add(x[8*i], x[8*i+4]); \ - } \ - res = wasm_f32x4_extract_lane(x[0], 0) + \ - wasm_f32x4_extract_lane(x[0], 1) + \ - wasm_f32x4_extract_lane(x[0], 2) + \ - wasm_f32x4_extract_lane(x[0], 3); \ -} - -#define GGML_F32_VEC GGML_F32x4 -#define GGML_F32_VEC_ZERO GGML_F32x4_ZERO -#define GGML_F32_VEC_SET1 GGML_F32x4_SET1 -#define GGML_F32_VEC_LOAD GGML_F32x4_LOAD -#define GGML_F32_VEC_STORE GGML_F32x4_STORE -#define GGML_F32_VEC_FMA GGML_F32x4_FMA -#define GGML_F32_VEC_ADD GGML_F32x4_ADD -#define GGML_F32_VEC_MUL GGML_F32x4_MUL -#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE - -// F16 WASM - -#define GGML_F16_STEP 16 -#define GGML_F16_EPR 4 - -inline static v128_t __wasm_f16x4_load(const ggml_fp16_t * p) { - float tmp[4]; - - tmp[0] = GGML_FP16_TO_FP32(p[0]); - tmp[1] = GGML_FP16_TO_FP32(p[1]); - tmp[2] = GGML_FP16_TO_FP32(p[2]); - tmp[3] = GGML_FP16_TO_FP32(p[3]); - - return wasm_v128_load(tmp); -} - -inline static void __wasm_f16x4_store(ggml_fp16_t * p, v128_t x) { - float tmp[4]; - - wasm_v128_store(tmp, x); - - p[0] = GGML_FP32_TO_FP16(tmp[0]); - p[1] = GGML_FP32_TO_FP16(tmp[1]); - p[2] = GGML_FP32_TO_FP16(tmp[2]); - p[3] = GGML_FP32_TO_FP16(tmp[3]); -} - -#define GGML_F16x4 v128_t -#define GGML_F16x4_ZERO wasm_f32x4_splat(0.0f) -#define GGML_F16x4_SET1(x) wasm_f32x4_splat(x) -#define GGML_F16x4_LOAD(x) __wasm_f16x4_load(x) -#define GGML_F16x4_STORE(x, y) __wasm_f16x4_store(x, y) -#define GGML_F16x4_FMA GGML_F32x4_FMA -#define GGML_F16x4_ADD wasm_f32x4_add -#define GGML_F16x4_MUL wasm_f32x4_mul -#define GGML_F16x4_REDUCE(res, x) \ -{ \ - for (int i = 0; i < GGML_F16_ARR/2; ++i) { \ - x[2*i] = wasm_f32x4_add(x[2*i], x[2*i+1]); \ - } \ - for (int i = 0; i < GGML_F16_ARR/4; ++i) { \ - x[4*i] = wasm_f32x4_add(x[4*i], x[4*i+2]); \ - } \ - for (int i = 0; i < GGML_F16_ARR/8; ++i) { \ - x[8*i] = wasm_f32x4_add(x[8*i], x[8*i+4]); \ - } \ - res = wasm_f32x4_extract_lane(x[0], 0) + \ - wasm_f32x4_extract_lane(x[0], 1) + \ - wasm_f32x4_extract_lane(x[0], 2) + \ - wasm_f32x4_extract_lane(x[0], 3); \ -} - -#define GGML_F16_VEC GGML_F16x4 -#define GGML_F16_VEC_ZERO GGML_F16x4_ZERO -#define GGML_F16_VEC_SET1 GGML_F16x4_SET1 -#define GGML_F16_VEC_LOAD(p, i) GGML_F16x4_LOAD(p) -#define 
GGML_F16_VEC_STORE(p, r, i) GGML_F16x4_STORE(p, r[i]) -#define GGML_F16_VEC_FMA GGML_F16x4_FMA -#define GGML_F16_VEC_ADD GGML_F16x4_ADD -#define GGML_F16_VEC_MUL GGML_F16x4_MUL -#define GGML_F16_VEC_REDUCE GGML_F16x4_REDUCE - -#elif defined(__SSE3__) - -#define GGML_SIMD - -// F32 SSE - -#define GGML_F32_STEP 32 -#define GGML_F32_EPR 4 - -#define GGML_F32x4 __m128 -#define GGML_F32x4_ZERO _mm_setzero_ps() -#define GGML_F32x4_SET1(x) _mm_set1_ps(x) -#define GGML_F32x4_LOAD _mm_loadu_ps -#define GGML_F32x4_STORE _mm_storeu_ps -#if defined(__FMA__) - // TODO: Does this work? - #define GGML_F32x4_FMA(a, b, c) _mm_fmadd_ps(b, c, a) -#else - #define GGML_F32x4_FMA(a, b, c) _mm_add_ps(_mm_mul_ps(b, c), a) -#endif -#define GGML_F32x4_ADD _mm_add_ps -#define GGML_F32x4_MUL _mm_mul_ps -#define GGML_F32x4_REDUCE(res, x) \ -{ \ - for (int i = 0; i < GGML_F32_ARR/2; ++i) { \ - x[2*i] = _mm_add_ps(x[2*i], x[2*i+1]); \ - } \ - for (int i = 0; i < GGML_F32_ARR/4; ++i) { \ - x[4*i] = _mm_add_ps(x[4*i], x[4*i+2]); \ - } \ - for (int i = 0; i < GGML_F32_ARR/8; ++i) { \ - x[8*i] = _mm_add_ps(x[8*i], x[8*i+4]); \ - } \ - const __m128 t0 = _mm_hadd_ps(x[0], x[0]); \ - res = _mm_cvtss_f32(_mm_hadd_ps(t0, t0)); \ -} -// TODO: is this optimal ? - -#define GGML_F32_VEC GGML_F32x4 -#define GGML_F32_VEC_ZERO GGML_F32x4_ZERO -#define GGML_F32_VEC_SET1 GGML_F32x4_SET1 -#define GGML_F32_VEC_LOAD GGML_F32x4_LOAD -#define GGML_F32_VEC_STORE GGML_F32x4_STORE -#define GGML_F32_VEC_FMA GGML_F32x4_FMA -#define GGML_F32_VEC_ADD GGML_F32x4_ADD -#define GGML_F32_VEC_MUL GGML_F32x4_MUL -#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE - -// F16 SSE - -#define GGML_F16_STEP 32 -#define GGML_F16_EPR 4 - -static inline __m128 __sse_f16x4_load(ggml_fp16_t *x) { - float tmp[4]; - - tmp[0] = GGML_FP16_TO_FP32(x[0]); - tmp[1] = GGML_FP16_TO_FP32(x[1]); - tmp[2] = GGML_FP16_TO_FP32(x[2]); - tmp[3] = GGML_FP16_TO_FP32(x[3]); - - return _mm_loadu_ps(tmp); -} - -static inline void __sse_f16x4_store(ggml_fp16_t *x, __m128 y) { - float arr[4]; - - _mm_storeu_ps(arr, y); - - x[0] = GGML_FP32_TO_FP16(arr[0]); - x[1] = GGML_FP32_TO_FP16(arr[1]); - x[2] = GGML_FP32_TO_FP16(arr[2]); - x[3] = GGML_FP32_TO_FP16(arr[3]); -} - -#define GGML_F32Cx4 __m128 -#define GGML_F32Cx4_ZERO _mm_setzero_ps() -#define GGML_F32Cx4_SET1(x) _mm_set1_ps(x) -#define GGML_F32Cx4_LOAD(x) __sse_f16x4_load(x) -#define GGML_F32Cx4_STORE(x, y) __sse_f16x4_store(x, y) -#define GGML_F32Cx4_FMA GGML_F32x4_FMA -#define GGML_F32Cx4_ADD _mm_add_ps -#define GGML_F32Cx4_MUL _mm_mul_ps -#define GGML_F32Cx4_REDUCE GGML_F32x4_REDUCE - -#define GGML_F16_VEC GGML_F32Cx4 -#define GGML_F16_VEC_ZERO GGML_F32Cx4_ZERO -#define GGML_F16_VEC_SET1 GGML_F32Cx4_SET1 -#define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx4_LOAD(p) -#define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE(p, r[i]) -#define GGML_F16_VEC_FMA GGML_F32Cx4_FMA -#define GGML_F16_VEC_ADD GGML_F32Cx4_ADD -#define GGML_F16_VEC_MUL GGML_F32Cx4_MUL -#define GGML_F16_VEC_REDUCE GGML_F32Cx4_REDUCE - -#endif - -// GGML_F32_ARR / GGML_F16_ARR -// number of registers to use per step -#ifdef GGML_SIMD -#define GGML_F32_ARR (GGML_F32_STEP/GGML_F32_EPR) -#define GGML_F16_ARR (GGML_F16_STEP/GGML_F16_EPR) -#endif - -// -// fundamental operations -// - -inline static void ggml_vec_set_i8(const int n, int8_t * x, const int8_t v) { for (int i = 0; i < n; ++i) x[i] = v; } - -inline static void ggml_vec_set_i16(const int n, int16_t * x, const int16_t v) { for (int i = 0; i < n; ++i) x[i] = v; } - -inline static void ggml_vec_set_i32(const int n, int32_t * x, 
const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; } - -inline static void ggml_vec_set_f16(const int n, ggml_fp16_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; } - -inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] + y[i]; } -inline static void ggml_vec_acc_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] += x[i]; } -inline static void ggml_vec_acc1_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] += v; } -inline static void ggml_vec_sub_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] - y[i]; } -inline static void ggml_vec_set_f32 (const int n, float * x, const float v) { for (int i = 0; i < n; ++i) x[i] = v; } -inline static void ggml_vec_cpy_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]; } -inline static void ggml_vec_neg_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = -x[i]; } -inline static void ggml_vec_mul_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]*y[i]; } -inline static void ggml_vec_div_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]/y[i]; } - -inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float * restrict x, const float * restrict y) { - ggml_float sumf = 0.0; - -#ifdef GGML_SIMD - const int np = (n & ~(GGML_F32_STEP - 1)); - - GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO }; - - GGML_F32_VEC ax[GGML_F32_ARR]; - GGML_F32_VEC ay[GGML_F32_ARR]; - - for (int i = 0; i < np; i += GGML_F32_STEP) { - for (int j = 0; j < GGML_F32_ARR; j++) { - ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR); - ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR); - - sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], ay[j]); - } - } - - // reduce sum0..sum3 to sum0 - GGML_F32_VEC_REDUCE(sumf, sum); - - // leftovers - for (int i = np; i < n; ++i) { - sumf += x[i]*y[i]; - } -#else - // scalar - for (int i = 0; i < n; ++i) { - sumf += x[i]*y[i]; - } -#endif - - *s = sumf; -} - -inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y) { - ggml_float sumf = 0.0; - -#if defined(GGML_SIMD) - const int np = (n & ~(GGML_F16_STEP - 1)); - - GGML_F16_VEC sum[GGML_F16_ARR] = { GGML_F16_VEC_ZERO }; - - GGML_F16_VEC ax[GGML_F16_ARR]; - GGML_F16_VEC ay[GGML_F16_ARR]; - - for (int i = 0; i < np; i += GGML_F16_STEP) { - for (int j = 0; j < GGML_F16_ARR; j++) { - ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j); - ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j); - - sum[j] = GGML_F16_VEC_FMA(sum[j], ax[j], ay[j]); - } - } - - // reduce sum0..sum3 to sum0 - GGML_F16_VEC_REDUCE(sumf, sum); - - // leftovers - for (int i = np; i < n; ++i) { - sumf += GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i]); - } -#else - for (int i = 0; i < n; ++i) { - sumf += GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i]); - } -#endif - - *s = sumf; -} - -// compute GGML_VEC_DOT_UNROLL dot products at once -// xs - x row stride in bytes -inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * restrict s, void * restrict xv, ggml_fp16_t * restrict y) { - ggml_float sumf[GGML_VEC_DOT_UNROLL] = { 0.0 }; - - ggml_fp16_t * restrict x[GGML_VEC_DOT_UNROLL]; - - for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) { - x[i] = 
(ggml_fp16_t *) ((char *) xv + i*xs); - } - -#if defined(GGML_SIMD) - const int np = (n & ~(GGML_F16_STEP - 1)); - - GGML_F16_VEC sum[GGML_VEC_DOT_UNROLL][GGML_F16_ARR] = { { GGML_F16_VEC_ZERO } }; - - GGML_F16_VEC ax[GGML_F16_ARR]; - GGML_F16_VEC ay[GGML_F16_ARR]; - - for (int i = 0; i < np; i += GGML_F16_STEP) { - for (int j = 0; j < GGML_F16_ARR; j++) { - ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j); - - for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) { - ax[j] = GGML_F16_VEC_LOAD(x[k] + i + j*GGML_F16_EPR, j); - - sum[k][j] = GGML_F16_VEC_FMA(sum[k][j], ax[j], ay[j]); - } - } - } - - // reduce sum0..sum3 to sum0 - for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) { - GGML_F16_VEC_REDUCE(sumf[k], sum[k]); - } - - // leftovers - for (int i = np; i < n; ++i) { - for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) { - sumf[j] += GGML_FP16_TO_FP32(x[j][i])*GGML_FP16_TO_FP32(y[i]); - } - } -#else - for (int i = 0; i < n; ++i) { - for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) { - sumf[j] += GGML_FP16_TO_FP32(x[j][i])*GGML_FP16_TO_FP32(y[i]); - } - } -#endif - - for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) { - s[i] = sumf[i]; - } -} - -inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float * restrict x, const float v) { -#if defined(GGML_SIMD) - const int np = (n & ~(GGML_F32_STEP - 1)); - - GGML_F32_VEC vx = GGML_F32_VEC_SET1(v); - - GGML_F32_VEC ax[GGML_F32_ARR]; - GGML_F32_VEC ay[GGML_F32_ARR]; - - for (int i = 0; i < np; i += GGML_F32_STEP) { - for (int j = 0; j < GGML_F32_ARR; j++) { - ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR); - ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR); - ay[j] = GGML_F32_VEC_FMA(ay[j], ax[j], vx); - - GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]); - } - } - - // leftovers - for (int i = np; i < n; ++i) { - y[i] += x[i]*v; - } -#else - // scalar - for (int i = 0; i < n; ++i) { - y[i] += x[i]*v; - } -#endif -} - -inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, ggml_fp16_t * restrict x, const float v) { -#if defined(GGML_SIMD) - const int np = (n & ~(GGML_F16_STEP - 1)); - - GGML_F16_VEC vx = GGML_F16_VEC_SET1(v); - - GGML_F16_VEC ax[GGML_F16_ARR]; - GGML_F16_VEC ay[GGML_F16_ARR]; - - for (int i = 0; i < np; i += GGML_F16_STEP) { - for (int j = 0; j < GGML_F16_ARR; j++) { - ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j); - ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j); - ay[j] = GGML_F16_VEC_FMA(ay[j], ax[j], vx); - - GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j); - } - } - - // leftovers - for (int i = np; i < n; ++i) { - GGML_ASSERT(false); - y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v); - } -#else - for (int i = 0; i < n; ++i) { - y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v); - } -#endif -} - -//inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] *= v; } -inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { -#if defined(GGML_SIMD) - const int np = (n & ~(GGML_F32_STEP - 1)); - - GGML_F32_VEC vx = GGML_F32_VEC_SET1(v); - - GGML_F32_VEC ay[GGML_F32_ARR]; - - for (int i = 0; i < np; i += GGML_F32_STEP) { - for (int j = 0; j < GGML_F32_ARR; j++) { - ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR); - ay[j] = GGML_F32_VEC_MUL(ay[j], vx); - - GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]); - } - } - - // leftovers - for (int i = np; i < n; ++i) { - y[i] *= v; - } -#else - // scalar - for (int i = 0; i < n; ++i) { - y[i] *= v; - } 
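// note: when GGML_SIMD is not defined, the scalar loop above scales the whole vector; in the
// GGML_SIMD branch, np = (n & ~(GGML_F32_STEP - 1)) rounds n down to a whole number of SIMD steps
// (GGML_F32_STEP is a power of two), so only the last n % GGML_F32_STEP elements are handled by the
// "leftovers" loop (e.g. with GGML_F32_STEP == 32 and n == 70, np == 64 and 6 elements remain).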
-#endif -} - -inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, x, x); *s = sqrt(*s); } -inline static void ggml_vec_sqr_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i]; } -inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrt(x[i]); } -inline static void ggml_vec_abs_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fabsf(x[i]); } -inline static void ggml_vec_sgn_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : ((x[i] < 0.f) ? -1.f : 0.f); } -inline static void ggml_vec_step_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : 0.f; } -inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; } - -static const ggml_float GELU_COEF_A = 0.044715; -static const ggml_float SQRT_2_OVER_PI = 0.79788456080286535587989211986876; - -inline static float ggml_gelu_f32(float x) { - return 0.5*x*(1.0 + tanh(SQRT_2_OVER_PI*x*(1.0 + GELU_COEF_A*x*x))); -} - -inline static void ggml_vec_gelu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { - const uint16_t * i16 = (const uint16_t *) x; - for (int i = 0; i < n; ++i) { - y[i] = table_gelu_f16[i16[i]]; - } -} - -#ifdef GGML_GELU_FP16 -inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) { - uint16_t t; - for (int i = 0; i < n; ++i) { - ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]); - memcpy(&t, &fp16, sizeof(uint16_t)); - y[i] = GGML_FP16_TO_FP32(table_gelu_f16[t]); - } -} -#else -inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) { - for (int i = 0; i < n; ++i) { - y[i] = ggml_gelu_f32(x[i]); - } -} -#endif - -inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) { -#ifndef GGML_USE_ACCELERATE - ggml_float sum = 0.0; - for (int i = 0; i < n; ++i) { - sum += x[i]; - } - *s = sum; -#else - vDSP_sve(x, 1, s, n); -#endif -} - -inline static void ggml_vec_max_f32(const int n, float * s, const float * x) { -#ifndef GGML_USE_ACCELERATE - ggml_float max = -INFINITY; - for (int i = 0; i < n; ++i) { - max = MAX(max, x[i]); - } - *s = max; -#else - vDSP_maxv(x, 1, s, n); -#endif -} - -inline static void ggml_vec_norm_inv_f32(const int n, float * s, const float * x) { ggml_vec_norm_f32(n, s, x); *s = 1./(*s); } - -// -// logging -// - -#if (GGML_DEBUG >= 1) -#define GGML_PRINT_DEBUG(...) printf(__VA_ARGS__) -#else -#define GGML_PRINT_DEBUG(...) -#endif - -#if (GGML_DEBUG >= 5) -#define GGML_PRINT_DEBUG_5(...) printf(__VA_ARGS__) -#else -#define GGML_PRINT_DEBUG_5(...) -#endif - -#if (GGML_DEBUG >= 10) -#define GGML_PRINT_DEBUG_10(...) printf(__VA_ARGS__) -#else -#define GGML_PRINT_DEBUG_10(...) -#endif - -#define GGML_PRINT(...) 
printf(__VA_ARGS__) - -// -// data types -// - -static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = { - sizeof(int8_t ), - sizeof(int16_t), - sizeof(int32_t), - sizeof(ggml_fp16_t), - sizeof(float ), -}; - -static const char * GGML_OP_LABEL[GGML_OP_COUNT] = { - "NONE", - - "DUP", - "ADD", - "SUB", - "MUL", - "DIV", - "SQR", - "SQRT", - "SUM", - "MEAN", - "REPEAT", - "ABS", - "SGN", - "NEG", - "STEP", - "RELU", - "GELU", - "NORM", - - "MUL_MAT", - - "SCALE", - "CPY", - "RESHAPE", - "VIEW", - "PERMUTE", - "TRANSPOSE", - "GET_ROWS", - "DIAG_MASK_INF", - "SOFT_MAX", - "ROPE", - "CONV_1D_1S", - "CONV_1D_2S", - - "FLASH_ATTN", - "FLASH_FF", -}; - -static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { - "none", - - "x", - "x+y", - "x-y", - "x*y", - "x/y", - "x^2", - "√x", - "Σx", - "Σx/n", - "repeat(x)", - "abs(x)", - "sgn(x)", - "-x", - "step(x)", - "relu(x)", - "gelu(x)", - "norm(x)", - - "X*Y", - - "x*v", - "x-\\>y", - "reshape(x)", - "view(x)", - "permute(x)", - "transpose(x)", - "get_rows(x)", - "diag_mask_inf(x)", - "soft_max(x)", - "rope(x)", - "conv_1d_1s(x)", - "conv_1d_2s(x)", - - "flash_attn(x)", - "flash_ff(x)", -}; - -// -// ggml object -// - -struct ggml_object { - size_t offs; - size_t size; - - struct ggml_object * next; - - char padding[8]; -}; - -static const size_t GGML_OBJECT_SIZE = sizeof(struct ggml_object); - -static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN"); -static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN"); - -// -// ggml context -// - -struct ggml_context { - size_t mem_size; - void * mem_buffer; - bool mem_buffer_owned; - - int n_objects; - - struct ggml_object * objects_begin; - struct ggml_object * objects_end; - - struct ggml_scratch scratch; - struct ggml_scratch scratch_save; -}; - -struct ggml_context_container { - bool used; - - struct ggml_context context; -}; - -// -// compute types -// - -enum ggml_task_type { - GGML_TASK_INIT = 0, - GGML_TASK_COMPUTE, - GGML_TASK_FINALIZE, -}; - -struct ggml_compute_params { - enum ggml_task_type type; - - int ith, nth; - - // work buffer for all threads - size_t wsize; - void * wdata; -}; - -// -// ggml state -// - -struct ggml_state { - struct ggml_context_container contexts[GGML_MAX_CONTEXTS]; -}; - -// global state -static struct ggml_state g_state; -static atomic_int g_state_barrier = 0; - -// barrier via spin lock -inline static void ggml_critical_section_start(void) { - int processing = atomic_fetch_add(&g_state_barrier, 1); - - while (processing > 0) { - // wait for other threads to finish - atomic_fetch_sub(&g_state_barrier, 1); - sched_yield(); // TODO: reconsider this - processing = atomic_fetch_add(&g_state_barrier, 1); - } -} - -// TODO: make this somehow automatically executed -// some sort of "sentry" mechanism -inline static void ggml_critical_section_end(void) { - atomic_fetch_sub(&g_state_barrier, 1); -} - -//////////////////////////////////////////////////////////////////////////////// - -void ggml_print_object(const struct ggml_object * obj) { - GGML_PRINT(" - ggml_object: offset = %zu, size = %zu, next = %p\n", - obj->offs, obj->size, (const void *) obj->next); -} - -void ggml_print_objects(const struct ggml_context * ctx) { - struct ggml_object * obj = ctx->objects_begin; - - GGML_PRINT("%s: objects in context %p:\n", __func__, (const void *) ctx); - - while (obj != NULL) { - ggml_print_object(obj); - obj = obj->next; - } - - GGML_PRINT("%s: --- end ---\n", 
__func__); -} - -int ggml_nelements(const struct ggml_tensor * tensor) { - static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); - - return tensor->ne[0]*tensor->ne[1]*tensor->ne[2]*tensor->ne[3]; -} - -int ggml_nrows(const struct ggml_tensor * tensor) { - static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); - - return tensor->ne[1]*tensor->ne[2]*tensor->ne[3]; -} - -size_t ggml_nbytes(const struct ggml_tensor * tensor) { - static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); - - return ggml_nelements(tensor)*GGML_TYPE_SIZE[tensor->type]; -} - -size_t ggml_type_size(enum ggml_type type) { - return GGML_TYPE_SIZE[type]; -} - -size_t ggml_element_size(const struct ggml_tensor * tensor) { - return GGML_TYPE_SIZE[tensor->type]; -} - -static inline bool ggml_is_scalar(const struct ggml_tensor * tensor) { - static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); - - return tensor->ne[0] == 1 && tensor->ne[1] == 1 && tensor->ne[2] == 1 && tensor->ne[3] == 1; -} - -static inline bool ggml_is_vector(const struct ggml_tensor * tensor) { - static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); - - return tensor->ne[1] == 1 && tensor->ne[2] == 1 && tensor->ne[3] == 1; -} - -static inline bool ggml_is_matrix(const struct ggml_tensor * tensor) { - static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); - - return tensor->ne[2] == 1 && tensor->ne[3] == 1; -} - -static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) { - static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); - - return - (t0->ne[0] == t1->ne[0]) && - (t0->ne[2] == t1->ne[2]) && - (t0->ne[3] == t1->ne[3]); -} - -static inline bool ggml_is_contiguous(const struct ggml_tensor * tensor) { - static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); - - return - tensor->nb[0] == GGML_TYPE_SIZE[tensor->type] && - tensor->nb[1] == tensor->nb[0]*tensor->ne[0] && - tensor->nb[2] == tensor->nb[1]*tensor->ne[1] && - tensor->nb[3] == tensor->nb[2]*tensor->ne[2]; -} - -static inline bool ggml_is_padded_1d(const struct ggml_tensor * tensor) { - static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); - - return - tensor->nb[0] == GGML_TYPE_SIZE[tensor->type] && - tensor->nb[2] == tensor->nb[1]*tensor->ne[1] && - tensor->nb[3] == tensor->nb[2]*tensor->ne[2]; -} - -static inline bool ggml_are_same_shape(const struct ggml_tensor * t0, const struct ggml_tensor * t1) { - static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); - - return - (t0->ne[0] == t1->ne[0] ) && - (t0->ne[1] == t1->ne[1] ) && - (t0->ne[2] == t1->ne[2] ) && - (t0->ne[3] == t1->ne[3] ); -} - -// check if t1 can be represented as a repeatition of t0 -static inline bool ggml_can_repeat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) { - static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); - - return - (t1->ne[0]%t0->ne[0] == 0) && - (t1->ne[1]%t0->ne[1] == 0) && - (t1->ne[2]%t0->ne[2] == 0) && - (t1->ne[3]%t0->ne[3] == 0); -} - -static inline int ggml_up32(int n) { - return (n + 31) & ~31; -} - -static inline int ggml_up64(int n) { - return (n + 63) & ~63; -} - -static inline int ggml_up(int n, int m) { - // assert m is a power of 2 - GGML_ASSERT((m & (m - 1)) == 0); - return (n + m - 1) & ~(m - 1); -} - -// assert that 
pointer is aligned to GGML_MEM_ALIGN -#define ggml_assert_aligned(ptr) \ - assert(((uintptr_t) (ptr))%GGML_MEM_ALIGN == 0) - -//////////////////////////////////////////////////////////////////////////////// - -struct ggml_context * ggml_init(struct ggml_init_params params) { - // make this function thread safe - ggml_critical_section_start(); - - static bool is_first_call = true; - - if (is_first_call) { - // initialize GELU, EXP and F32 tables - { - const uint64_t t_start = ggml_time_us(); UNUSED(t_start); - - ggml_fp16_t ii; - for (int i = 0; i < (1 << 16); ++i) { - uint16_t ui = i; - memcpy(&ii, &ui, sizeof(ii)); - const float f = table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(ii); - table_gelu_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_f32(f)); - table_exp_f16[i] = GGML_FP32_TO_FP16(exp(f)); - } - - const uint64_t t_end = ggml_time_us(); UNUSED(t_end); - - GGML_PRINT_DEBUG("%s: GELU and EXP tables initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f); - } - - // initialize g_state - { - const uint64_t t_start = ggml_time_us(); UNUSED(t_start); - - g_state = (struct ggml_state) { - /*.contexts =*/ { { 0 } }, - }; - - for (int i = 0; i < GGML_MAX_CONTEXTS; ++i) { - g_state.contexts[i].used = false; - } - - const uint64_t t_end = ggml_time_us(); UNUSED(t_end); - - GGML_PRINT_DEBUG("%s: g_state initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f); - } - - is_first_call = false; - } - - // find non-used context in g_state - struct ggml_context * ctx = NULL; - - for (int i = 0; i < GGML_MAX_CONTEXTS; i++) { - if (!g_state.contexts[i].used) { - g_state.contexts[i].used = true; - ctx = &g_state.contexts[i].context; - - GGML_PRINT_DEBUG("%s: found unused context %d\n", __func__, i); - break; - } - } - - if (ctx == NULL) { - GGML_PRINT_DEBUG("%s: no unused context found\n", __func__); - - ggml_critical_section_end(); - - return NULL; - } - - *ctx = (struct ggml_context) { - /*.mem_size =*/ params.mem_size, - /*.mem_buffer =*/ params.mem_buffer ? params.mem_buffer : malloc(params.mem_size), - /*.mem_buffer_owned =*/ params.mem_buffer ? false : true, - /*.n_objects =*/ 0, - /*.objects_begin =*/ NULL, - /*.objects_end =*/ NULL, - /*.scratch =*/ { 0, 0, NULL, }, - /*.scratch_save =*/ { 0, 0, NULL, }, - }; - - ggml_assert_aligned(ctx->mem_buffer); - - GGML_PRINT_DEBUG("%s: context initialized\n", __func__); - - ggml_critical_section_end(); - - return ctx; -} - -void ggml_free(struct ggml_context * ctx) { - // make this function thread safe - ggml_critical_section_start(); - - bool found = false; - - for (int i = 0; i < GGML_MAX_CONTEXTS; i++) { - if (&g_state.contexts[i].context == ctx) { - g_state.contexts[i].used = false; - - GGML_PRINT_DEBUG("%s: context %d with %d objects has been freed. memory used = %zu\n", - __func__, i, ctx->n_objects, ctx->objects_end->offs + ctx->objects_end->size); - - if (ctx->mem_buffer_owned) { - free(ctx->mem_buffer); - } - - found = true; - break; - } - } - - if (!found) { - GGML_PRINT_DEBUG("%s: context not found\n", __func__); - } - - ggml_critical_section_end(); -} - -size_t ggml_used_mem(const struct ggml_context * ctx) { - return ctx->objects_end->offs + ctx->objects_end->size; -} - -size_t ggml_set_scratch(struct ggml_context * ctx, struct ggml_scratch scratch) { - const size_t result = ctx->scratch.data ? 
ctx->scratch.offs : 0; - - ctx->scratch = scratch; - - return result; -} - -//////////////////////////////////////////////////////////////////////////////// - -struct ggml_tensor * ggml_new_tensor_impl( - struct ggml_context * ctx, - enum ggml_type type, - int n_dims, - const int* ne, - void* data) { - // always insert objects at the end of the context's memory pool - struct ggml_object * obj_cur = ctx->objects_end; - - const size_t cur_offs = obj_cur == NULL ? 0 : obj_cur->offs; - const size_t cur_size = obj_cur == NULL ? 0 : obj_cur->size; - const size_t cur_end = cur_offs + cur_size; - - size_t size_needed = 0; - - if (data == NULL) { - size_needed += GGML_TYPE_SIZE[type]; - for (int i = 0; i < n_dims; i++) { - size_needed *= ne[i]; - } - // align to GGML_MEM_ALIGN - size_needed = ((size_needed + GGML_MEM_ALIGN - 1)/GGML_MEM_ALIGN)*GGML_MEM_ALIGN; - } - - char * const mem_buffer = ctx->mem_buffer; - struct ggml_object * const obj_new = (struct ggml_object *)(mem_buffer + cur_end); - - if (ctx->scratch.data == NULL || data != NULL) { - size_needed += sizeof(struct ggml_tensor); - - if (cur_end + size_needed + GGML_OBJECT_SIZE > ctx->mem_size) { - GGML_PRINT("%s: not enough space in the context's memory pool (needed %zu, available %zu)\n", - __func__, cur_end + size_needed + GGML_OBJECT_SIZE, ctx->mem_size); - assert(false); - return NULL; - } - - *obj_new = (struct ggml_object) { - .offs = cur_end + GGML_OBJECT_SIZE, - .size = size_needed, - .next = NULL, - }; - } else { - if (ctx->scratch.offs + size_needed > ctx->scratch.size) { - GGML_PRINT("%s: not enough space in the scratch memory\n", __func__); - assert(false); - return NULL; - } - - if (cur_end + sizeof(struct ggml_tensor) + GGML_OBJECT_SIZE > ctx->mem_size) { - GGML_PRINT("%s: not enough space in the context's memory pool (needed %zu, available %zu)\n", - __func__, cur_end + sizeof(struct ggml_tensor) + GGML_OBJECT_SIZE, ctx->mem_size); - assert(false); - return NULL; - } - - data = (char * const) ctx->scratch.data + ctx->scratch.offs; - - *obj_new = (struct ggml_object) { - .offs = cur_end + GGML_OBJECT_SIZE, - .size = sizeof(struct ggml_tensor), - .next = NULL, - }; - - //printf("scratch offs = %zu, size_needed = %zu\n", ctx->scratch.offs, size_needed); - - ctx->scratch.offs += size_needed; - } - - if (obj_cur != NULL) { - obj_cur->next = obj_new; - } else { - // this is the first object in this context - ctx->objects_begin = obj_new; - } - - ctx->objects_end = obj_new; - - //printf("%s: inserted new object at %zu, size = %zu\n", __func__, cur_end, obj_new->size); - - struct ggml_tensor * const result = (struct ggml_tensor *)(mem_buffer + obj_new->offs); - - ggml_assert_aligned(result); - - *result = (struct ggml_tensor) { - /*.type =*/ type, - /*.n_dims =*/ n_dims, - /*.ne =*/ { 1, 1, 1, 1 }, - /*.nb =*/ { 0, 0, 0, 0 }, - /*.op =*/ GGML_OP_NONE, - /*.is_param =*/ false, - /*.grad =*/ NULL, - /*.src0 =*/ NULL, - /*.src1 =*/ NULL, - /*.opt =*/ { NULL }, - /*.n_tasks =*/ 0, - /*.perf_runs =*/ 0, - /*.perf_cycles =*/ 0, - /*.perf_time_us =*/ 0, - /*.data =*/ data == NULL ? 
(void *)(result + 1) : data, - /*.pad =*/ { 0 }, - }; - - ggml_assert_aligned(result->data); - - for (int i = 0; i < n_dims; i++) { - result->ne[i] = ne[i]; - } - - result->nb[0] = GGML_TYPE_SIZE[type]; - for (int i = 1; i < GGML_MAX_DIMS; i++) { - result->nb[i] = result->nb[i - 1]*result->ne[i - 1]; - } - - ctx->n_objects++; - - return result; -} - -struct ggml_tensor * ggml_new_tensor( - struct ggml_context * ctx, - enum ggml_type type, - int n_dims, - const int * ne) { - return ggml_new_tensor_impl(ctx, type, n_dims, ne, NULL); -} - -struct ggml_tensor * ggml_new_tensor_1d( - struct ggml_context * ctx, - enum ggml_type type, - int ne0) { - return ggml_new_tensor(ctx, type, 1, &ne0); -} - -struct ggml_tensor * ggml_new_tensor_2d( - struct ggml_context * ctx, - enum ggml_type type, - int ne0, - int ne1) { - const int ne[2] = { ne0, ne1 }; - return ggml_new_tensor(ctx, type, 2, ne); -} - -struct ggml_tensor * ggml_new_tensor_3d( - struct ggml_context * ctx, - enum ggml_type type, - int ne0, - int ne1, - int ne2) { - const int ne[3] = { ne0, ne1, ne2 }; - return ggml_new_tensor(ctx, type, 3, ne); -} - -struct ggml_tensor * ggml_new_tensor_4d( - struct ggml_context * ctx, - enum ggml_type type, - int ne0, - int ne1, - int ne2, - int ne3) { - const int ne[4] = { ne0, ne1, ne2, ne3 }; - return ggml_new_tensor(ctx, type, 4, ne); -} - -struct ggml_tensor * ggml_new_i32(struct ggml_context * ctx, int32_t value) { - ctx->scratch_save = ctx->scratch; - ctx->scratch.data = NULL; - - struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1); - - ctx->scratch = ctx->scratch_save; - - ggml_set_i32(result, value); - - return result; -} - -struct ggml_tensor * ggml_new_f32(struct ggml_context * ctx, float value) { - ctx->scratch_save = ctx->scratch; - ctx->scratch.data = NULL; - - struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1); - - ctx->scratch = ctx->scratch_save; - - ggml_set_f32(result, value); - - return result; -} - -struct ggml_tensor * ggml_dup_tensor(struct ggml_context * ctx, const struct ggml_tensor * src) { - return ggml_new_tensor_impl(ctx, src->type, src->n_dims, src->ne, NULL); -} - -struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor) { - memset(tensor->data, 0, ggml_nbytes(tensor)); - return tensor; -} - -struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value) { - const int n = ggml_nrows(tensor); - const int nc = tensor->ne[0]; - const size_t n1 = tensor->nb[1]; - - char * const data = tensor->data; - - switch (tensor->type) { - case GGML_TYPE_I8: - { - assert(tensor->nb[0] == sizeof(int8_t)); - for (int i = 0; i < n; i++) { - ggml_vec_set_i8(nc, (int8_t *)(data + i*n1), value); - } - } break; - case GGML_TYPE_I16: - { - assert(tensor->nb[0] == sizeof(int16_t)); - for (int i = 0; i < n; i++) { - ggml_vec_set_i16(nc, (int16_t *)(data + i*n1), value); - } - } break; - case GGML_TYPE_I32: - { - assert(tensor->nb[0] == sizeof(int32_t)); - for (int i = 0; i < n; i++) { - ggml_vec_set_i32(nc, (int32_t *)(data + i*n1), value); - } - } break; - case GGML_TYPE_F16: - { - assert(tensor->nb[0] == sizeof(ggml_fp16_t)); - for (int i = 0; i < n; i++) { - ggml_vec_set_f16(nc, (ggml_fp16_t *)(data + i*n1), value); - } - } break; - case GGML_TYPE_F32: - { - assert(tensor->nb[0] == sizeof(float)); - for (int i = 0; i < n; i++) { - ggml_vec_set_f32(nc, (float *)(data + i*n1), value); - } - } break; - case GGML_TYPE_COUNT: - { - assert(false); - } break; - } - - return tensor; -} - -struct ggml_tensor * ggml_set_f32(struct ggml_tensor 
* tensor, float value) { - const int n = ggml_nrows(tensor); - const int nc = tensor->ne[0]; - const size_t n1 = tensor->nb[1]; - - char * const data = tensor->data; - - switch (tensor->type) { - case GGML_TYPE_I8: - { - assert(tensor->nb[0] == sizeof(int8_t)); - for (int i = 0; i < n; i++) { - ggml_vec_set_i8(nc, (int8_t *)(data + i*n1), value); - } - } break; - case GGML_TYPE_I16: - { - assert(tensor->nb[0] == sizeof(int16_t)); - for (int i = 0; i < n; i++) { - ggml_vec_set_i16(nc, (int16_t *)(data + i*n1), value); - } - } break; - case GGML_TYPE_I32: - { - assert(tensor->nb[0] == sizeof(int32_t)); - for (int i = 0; i < n; i++) { - ggml_vec_set_i32(nc, (int32_t *)(data + i*n1), value); - } - } break; - case GGML_TYPE_F16: - { - assert(tensor->nb[0] == sizeof(ggml_fp16_t)); - for (int i = 0; i < n; i++) { - ggml_vec_set_f16(nc, (ggml_fp16_t *)(data + i*n1), value); - } - } break; - case GGML_TYPE_F32: - { - assert(tensor->nb[0] == sizeof(float)); - for (int i = 0; i < n; i++) { - ggml_vec_set_f32(nc, (float *)(data + i*n1), value); - } - } break; - case GGML_TYPE_COUNT: - { - assert(false); - } break; - } - - return tensor; -} - -int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) { - switch (tensor->type) { - case GGML_TYPE_I8: - { - GGML_ASSERT(tensor->nb[0] == sizeof(int8_t)); - return ((int8_t *)(tensor->data))[i]; - } break; - case GGML_TYPE_I16: - { - GGML_ASSERT(tensor->nb[0] == sizeof(int16_t)); - return ((int16_t *)(tensor->data))[i]; - } break; - case GGML_TYPE_I32: - { - GGML_ASSERT(tensor->nb[0] == sizeof(int32_t)); - return ((int32_t *)(tensor->data))[i]; - } break; - case GGML_TYPE_F16: - { - GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t)); - return GGML_FP16_TO_FP32(((ggml_fp16_t *)(tensor->data))[i]); - } break; - case GGML_TYPE_F32: - { - GGML_ASSERT(tensor->nb[0] == sizeof(float)); - return ((float *)(tensor->data))[i]; - } break; - case GGML_TYPE_COUNT: - { - GGML_ASSERT(false); - } break; - } - - return 0.0f; -} - -void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) { - switch (tensor->type) { - case GGML_TYPE_I8: - { - GGML_ASSERT(tensor->nb[0] == sizeof(int8_t)); - ((int8_t *)(tensor->data))[i] = value; - } break; - case GGML_TYPE_I16: - { - GGML_ASSERT(tensor->nb[0] == sizeof(int16_t)); - ((int16_t *)(tensor->data))[i] = value; - } break; - case GGML_TYPE_I32: - { - GGML_ASSERT(tensor->nb[0] == sizeof(int32_t)); - ((int32_t *)(tensor->data))[i] = value; - } break; - case GGML_TYPE_F16: - { - GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t)); - ((ggml_fp16_t *)(tensor->data))[i] = GGML_FP32_TO_FP16(value); - } break; - case GGML_TYPE_F32: - { - GGML_ASSERT(tensor->nb[0] == sizeof(float)); - ((float *)(tensor->data))[i] = value; - } break; - case GGML_TYPE_COUNT: - { - GGML_ASSERT(false); - } break; - } -} - -float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) { - switch (tensor->type) { - case GGML_TYPE_I8: - { - GGML_ASSERT(tensor->nb[0] == sizeof(int8_t)); - return ((int8_t *)(tensor->data))[i]; - } break; - case GGML_TYPE_I16: - { - GGML_ASSERT(tensor->nb[0] == sizeof(int16_t)); - return ((int16_t *)(tensor->data))[i]; - } break; - case GGML_TYPE_I32: - { - GGML_ASSERT(tensor->nb[0] == sizeof(int32_t)); - return ((int32_t *)(tensor->data))[i]; - } break; - case GGML_TYPE_F16: - { - GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t)); - return GGML_FP16_TO_FP32(((ggml_fp16_t *)(tensor->data))[i]); - } break; - case GGML_TYPE_F32: - { - GGML_ASSERT(tensor->nb[0] == sizeof(float)); - return ((float 
*)(tensor->data))[i]; - } break; - case GGML_TYPE_COUNT: - { - GGML_ASSERT(false); - } break; - } - - return 0.0f; -} - -void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) { - switch (tensor->type) { - case GGML_TYPE_I8: - { - GGML_ASSERT(tensor->nb[0] == sizeof(int8_t)); - ((int8_t *)(tensor->data))[i] = value; - } break; - case GGML_TYPE_I16: - { - GGML_ASSERT(tensor->nb[0] == sizeof(int16_t)); - ((int16_t *)(tensor->data))[i] = value; - } break; - case GGML_TYPE_I32: - { - GGML_ASSERT(tensor->nb[0] == sizeof(int32_t)); - ((int32_t *)(tensor->data))[i] = value; - } break; - case GGML_TYPE_F16: - { - GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t)); - ((ggml_fp16_t *)(tensor->data))[i] = GGML_FP32_TO_FP16(value); - } break; - case GGML_TYPE_F32: - { - GGML_ASSERT(tensor->nb[0] == sizeof(float)); - ((float *)(tensor->data))[i] = value; - } break; - case GGML_TYPE_COUNT: - { - GGML_ASSERT(false); - } break; - } -} - -void * ggml_get_data(const struct ggml_tensor * tensor) { - return tensor->data; -} - -float * ggml_get_data_f32(const struct ggml_tensor * tensor) { - assert(tensor->type == GGML_TYPE_F32); - return (float *)(tensor->data); -} - -struct ggml_tensor * ggml_view_tensor( - struct ggml_context * ctx, - const struct ggml_tensor * src) { - return ggml_new_tensor_impl(ctx, src->type, src->n_dims, src->ne, src->data); -} - -//////////////////////////////////////////////////////////////////////////////// - -// ggml_dup - -struct ggml_tensor * ggml_dup_impl( - struct ggml_context * ctx, - struct ggml_tensor * a, - bool inplace) { - bool is_node = false; - - if (!inplace && (a->grad)) { - is_node = true; - } - - struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - - result->op = GGML_OP_DUP; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src0 = a; - result->src1 = NULL; - - return result; -} - -struct ggml_tensor * ggml_dup( - struct ggml_context * ctx, - struct ggml_tensor * a) { - return ggml_dup_impl(ctx, a, false); -} - -struct ggml_tensor * ggml_dup_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a) { - return ggml_dup_impl(ctx, a, true); -} - -// ggml_add - -struct ggml_tensor * ggml_add_impl( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - bool inplace) { - assert(ggml_are_same_shape(a, b)); - - bool is_node = false; - - if (!inplace && (a->grad || b->grad)) { - is_node = true; - } - - struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - - result->op = GGML_OP_ADD; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src0 = a; - result->src1 = b; - - return result; -} - -struct ggml_tensor * ggml_add( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b) { - return ggml_add_impl(ctx, a, b, false); -} - -struct ggml_tensor * ggml_add_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b) { - return ggml_add_impl(ctx, a, b, true); -} - -// ggml_sub - -struct ggml_tensor * ggml_sub_impl( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - bool inplace) { - assert(ggml_are_same_shape(a, b)); - - bool is_node = false; - - if (!inplace && (a->grad || b->grad)) { - is_node = true; - } - - struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - - result->op = GGML_OP_SUB; - result->grad = is_node ? 
ggml_dup_tensor(ctx, result) : NULL; - result->src0 = a; - result->src1 = b; - - return result; -} - -struct ggml_tensor * ggml_sub( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b) { - return ggml_sub_impl(ctx, a, b, false); -} - -struct ggml_tensor * ggml_sub_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b) { - return ggml_sub_impl(ctx, a, b, true); -} - -// ggml_mul - -struct ggml_tensor * ggml_mul_impl( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - bool inplace) { - assert(ggml_are_same_shape(a, b)); - - bool is_node = false; - - if (!inplace && (a->grad || b->grad)) { - is_node = true; - } - - if (inplace) { - assert(is_node == false); - } - - struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - - result->op = GGML_OP_MUL; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src0 = a; - result->src1 = b; - - return result; -} - -struct ggml_tensor * ggml_mul( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b) { - return ggml_mul_impl(ctx, a, b, false); -} - -struct ggml_tensor * ggml_mul_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b) { - return ggml_mul_impl(ctx, a, b, true); -} - -// ggml_div - -struct ggml_tensor * ggml_div_impl( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - bool inplace) { - assert(ggml_are_same_shape(a, b)); - - bool is_node = false; - - if (!inplace && (a->grad || b->grad)) { - is_node = true; - } - - if (inplace) { - assert(is_node == false); - } - - struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - - result->op = GGML_OP_DIV; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src0 = a; - result->src1 = b; - - return result; -} - -struct ggml_tensor * ggml_div( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b) { - return ggml_div_impl(ctx, a, b, false); -} - -struct ggml_tensor * ggml_div_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b) { - return ggml_div_impl(ctx, a, b, true); -} - -// ggml_sqr - -struct ggml_tensor * ggml_sqr_impl( - struct ggml_context * ctx, - struct ggml_tensor * a, - bool inplace) { - bool is_node = false; - - if (!inplace && (a->grad)) { - is_node = true; - } - - struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - - result->op = GGML_OP_SQR; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src0 = a; - result->src1 = NULL; - - return result; -} - -struct ggml_tensor * ggml_sqr( - struct ggml_context * ctx, - struct ggml_tensor * a) { - return ggml_sqr_impl(ctx, a, false); -} - -struct ggml_tensor * ggml_sqr_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a) { - return ggml_sqr_impl(ctx, a, true); -} - -// ggml_sqrt - -struct ggml_tensor * ggml_sqrt_impl( - struct ggml_context * ctx, - struct ggml_tensor * a, - bool inplace) { - bool is_node = false; - - if (!inplace && (a->grad)) { - is_node = true; - } - - struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - - result->op = GGML_OP_SQRT; - result->grad = is_node ? 
ggml_dup_tensor(ctx, result) : NULL; - result->src0 = a; - result->src1 = NULL; - - return result; -} - -struct ggml_tensor * ggml_sqrt( - struct ggml_context * ctx, - struct ggml_tensor * a) { - return ggml_sqrt_impl(ctx, a, false); -} - -struct ggml_tensor * ggml_sqrt_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a) { - return ggml_sqrt_impl(ctx, a, true); -} - -// ggml_sum - -struct ggml_tensor * ggml_sum( - struct ggml_context * ctx, - struct ggml_tensor * a) { - bool is_node = false; - - if (a->grad) { - is_node = true; - } - - struct ggml_tensor * result = ggml_new_tensor_1d(ctx, a->type, 1); - - result->op = GGML_OP_SUM; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src0 = a; - result->src1 = NULL; - - return result; -} - -// ggml_mean - -struct ggml_tensor * ggml_mean( - struct ggml_context * ctx, - struct ggml_tensor * a) { - bool is_node = false; - - if (a->grad) { - assert(false); // TODO: implement - is_node = true; - } - - int ne[GGML_MAX_DIMS] = { 1, a->ne[1], a->ne[2], a->ne[3] }; - struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, a->n_dims, ne); - - result->op = GGML_OP_MEAN; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src0 = a; - result->src1 = NULL; - - return result; -} - -// ggml_repeat - -struct ggml_tensor * ggml_repeat( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b) { - assert(ggml_can_repeat(a, b)); - - bool is_node = false; - - if (a->grad) { - is_node = true; - } - - if (ggml_are_same_shape(a, b) && !is_node) { - return a; - } - - struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, b->n_dims, b->ne); - - result->op = GGML_OP_REPEAT; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src0 = a; - result->src1 = b; - - return result; -} - -// ggml_abs - -struct ggml_tensor * ggml_abs_impl( - struct ggml_context * ctx, - struct ggml_tensor * a, - bool inplace) { - bool is_node = false; - - if (!inplace && (a->grad)) { - is_node = true; - } - - struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - - result->op = GGML_OP_ABS; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src0 = a; - result->src1 = NULL; - - return result; -} - -struct ggml_tensor * ggml_abs( - struct ggml_context * ctx, - struct ggml_tensor * a) { - return ggml_abs_impl(ctx, a, false); -} - -struct ggml_tensor * ggml_abs_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a) { - return ggml_abs_impl(ctx, a, true); -} - - -// ggml_sgn - -struct ggml_tensor * ggml_sgn_impl( - struct ggml_context * ctx, - struct ggml_tensor * a, - bool inplace) { - bool is_node = false; - - if (!inplace && (a->grad)) { - is_node = true; - } - - struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - - result->op = GGML_OP_SGN; - result->grad = is_node ? 
ggml_dup_tensor(ctx, result) : NULL; - result->src0 = a; - result->src1 = NULL; - - return result; -} - -struct ggml_tensor * ggml_sgn( - struct ggml_context * ctx, - struct ggml_tensor * a) { - return ggml_sgn_impl(ctx, a, false); -} - -struct ggml_tensor * ggml_sgn_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a) { - return ggml_sgn_impl(ctx, a, true); -} - -// ggml_neg - -struct ggml_tensor * ggml_neg_impl( - struct ggml_context * ctx, - struct ggml_tensor * a, - bool inplace) { - bool is_node = false; - - if (!inplace && (a->grad)) { - is_node = true; - } - - struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - - result->op = GGML_OP_NEG; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src0 = a; - result->src1 = NULL; - - return result; -} - -struct ggml_tensor * ggml_neg( - struct ggml_context * ctx, - struct ggml_tensor * a) { - return ggml_neg_impl(ctx, a, false); -} - -struct ggml_tensor * ggml_neg_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a) { - return ggml_neg_impl(ctx, a, true); -} - -// ggml_step - -struct ggml_tensor * ggml_step_impl( - struct ggml_context * ctx, - struct ggml_tensor * a, - bool inplace) { - bool is_node = false; - - if (!inplace && (a->grad)) { - is_node = true; - } - - struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - - result->op = GGML_OP_STEP; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src0 = a; - result->src1 = NULL; - - return result; -} - -struct ggml_tensor * ggml_step( - struct ggml_context * ctx, - struct ggml_tensor * a) { - return ggml_step_impl(ctx, a, false); -} - -struct ggml_tensor * ggml_step_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a) { - return ggml_step_impl(ctx, a, true); -} - -// ggml_relu - -struct ggml_tensor * ggml_relu_impl( - struct ggml_context * ctx, - struct ggml_tensor * a, - bool inplace) { - bool is_node = false; - - if (!inplace && (a->grad)) { - is_node = true; - } - - struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - - result->op = GGML_OP_RELU; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src0 = a; - result->src1 = NULL; - - return result; -} - -struct ggml_tensor * ggml_relu( - struct ggml_context * ctx, - struct ggml_tensor * a) { - return ggml_relu_impl(ctx, a, false); -} - -struct ggml_tensor * ggml_relu_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a) { - return ggml_relu_impl(ctx, a, true); -} - -// ggml_gelu - -struct ggml_tensor * ggml_gelu_impl( - struct ggml_context * ctx, - struct ggml_tensor * a, - bool inplace) { - bool is_node = false; - - if (!inplace && (a->grad)) { - is_node = true; - } - - struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - - result->op = GGML_OP_GELU; - result->grad = is_node ? 
ggml_dup_tensor(ctx, result) : NULL; - result->src0 = a; - result->src1 = NULL; - - return result; -} - -struct ggml_tensor * ggml_gelu( - struct ggml_context * ctx, - struct ggml_tensor * a) { - return ggml_gelu_impl(ctx, a, false); -} - -struct ggml_tensor * ggml_gelu_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a) { - return ggml_gelu_impl(ctx, a, true); -} - -// ggml_norm - -struct ggml_tensor * ggml_norm_impl( - struct ggml_context * ctx, - struct ggml_tensor * a, - bool inplace) { - bool is_node = false; - - if (!inplace && (a->grad)) { - assert(false); // TODO: implement backward - is_node = true; - } - - struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - - result->op = GGML_OP_NORM; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src0 = a; - result->src1 = NULL; // TODO: maybe store epsilon here? - - return result; -} - -struct ggml_tensor * ggml_norm( - struct ggml_context * ctx, - struct ggml_tensor * a) { - return ggml_norm_impl(ctx, a, false); -} - -struct ggml_tensor * ggml_norm_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a) { - return ggml_norm_impl(ctx, a, true); -} - -// ggml_mul_mat - -struct ggml_tensor * ggml_mul_mat( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b) { - assert(ggml_can_mul_mat(a, b)); - - bool is_node = false; - - if (a->grad || b->grad) { - is_node = true; - } - - const int ne[4] = { a->ne[1], b->ne[1], a->ne[2], b->ne[3] }; - struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MIN(a->n_dims, b->n_dims), ne); - - result->op = GGML_OP_MUL_MAT; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src0 = a; - result->src1 = b; - - return result; -} - -// ggml_scale - -struct ggml_tensor * ggml_scale_impl( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - bool inplace) { - assert(ggml_is_scalar(b)); - assert(ggml_is_padded_1d(a)); - - bool is_node = false; - - if (!inplace && (a->grad || b->grad)) { - assert(false); // TODO: implement backward - is_node = true; - } - - // TODO: when implement backward, fix this: - //struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - struct ggml_tensor * result = ggml_view_tensor(ctx, a); - - result->op = GGML_OP_SCALE; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src0 = a; - result->src1 = b; - - return result; -} - -struct ggml_tensor * ggml_scale( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b) { - return ggml_scale_impl(ctx, a, b, false); -} - -struct ggml_tensor * ggml_scale_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b) { - return ggml_scale_impl(ctx, a, b, true); -} - -// ggml_cpy - -struct ggml_tensor * ggml_cpy_impl( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - bool inplace) { - assert(ggml_nelements(a) == ggml_nelements(b)); - - bool is_node = false; - - if (!inplace && (a->grad || b->grad)) { - assert(false); // TODO: implement backward - is_node = true; - } - - // make a view of the destination - struct ggml_tensor * result = ggml_view_tensor(ctx, b); - - result->op = GGML_OP_CPY; - result->grad = is_node ? 
ggml_dup_tensor(ctx, result) : NULL; - result->src0 = a; - result->src1 = b; - - return result; -} - -struct ggml_tensor * ggml_cpy( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b) { - return ggml_cpy_impl(ctx, a, b, false); -} - -struct ggml_tensor * ggml_cpy_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b) { - return ggml_cpy_impl(ctx, a, b, true); -} - -// ggml_reshape - -struct ggml_tensor * ggml_reshape( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b) { - assert(ggml_is_contiguous(a)); - assert(ggml_is_contiguous(b)); - assert(ggml_nelements(a) == ggml_nelements(b)); - - bool is_node = false; - - if (a->grad || b->grad) { - assert(false); // TODO: implement backward - is_node = true; - } - - struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, b->n_dims, b->ne, a->data); - - result->op = GGML_OP_RESHAPE; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src0 = a; - result->src1 = NULL; - - return result; -} - -struct ggml_tensor * ggml_reshape_2d( - struct ggml_context * ctx, - struct ggml_tensor * a, - int ne0, - int ne1) { - assert(ggml_is_contiguous(a)); - assert(ggml_nelements(a) == ne0*ne1); - - bool is_node = false; - - if (a->grad) { - assert(false); // TODO: implement backward - is_node = true; - } - - const int ne[2] = { ne0, ne1 }; - struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 2, ne, a->data); - - result->op = GGML_OP_RESHAPE; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src0 = a; - result->src1 = NULL; - - return result; -} - -struct ggml_tensor * ggml_reshape_3d( - struct ggml_context * ctx, - struct ggml_tensor * a, - int ne0, - int ne1, - int ne2) { - assert(ggml_is_contiguous(a)); - assert(ggml_nelements(a) == ne0*ne1*ne2); - - bool is_node = false; - - if (a->grad) { - assert(false); // TODO: implement backward - is_node = true; - } - - const int ne[3] = { ne0, ne1, ne2 }; - struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 3, ne, a->data); - - result->op = GGML_OP_RESHAPE; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src0 = a; - result->src1 = NULL; - - return result; -} - -// ggml_view_1d - -struct ggml_tensor * ggml_view_1d( - struct ggml_context * ctx, - struct ggml_tensor * a, - int ne0, - size_t offset) { - if (a->grad) { - assert(false); // gradient propagation is not supported - } - - struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 1, &ne0, (char *) a->data + offset); - - result->op = GGML_OP_VIEW; - result->grad = NULL; - result->src0 = a; - result->src1 = NULL; // TODO: maybe store the offset here? - - return result; -} - -// ggml_view_2d - -struct ggml_tensor * ggml_view_2d( - struct ggml_context * ctx, - struct ggml_tensor * a, - int ne0, - int ne1, - size_t nb1, - size_t offset) { - if (a->grad) { - assert(false); // gradient propagation is not supported - } - - const int ne[GGML_MAX_DIMS] = { ne0, ne1, 1, 1 }; - - struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 2, ne, (char *) a->data + offset); - - result->nb[1] = nb1; - result->nb[2] = result->nb[1]*ne1; - result->nb[3] = result->nb[2]; - - result->op = GGML_OP_VIEW; - result->grad = NULL; - result->src0 = a; - result->src1 = NULL; // TODO: maybe store the offset here? 
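// illustrative usage sketch (hypothetical names ctx, m, r, k): view offsets are expressed in bytes,
// so row r of a contiguous F32 matrix m can be aliased without copying as
//   struct ggml_tensor * row = ggml_view_1d(ctx, m, m->ne[0], r*m->nb[1]);
// and its first k rows as a 2-D view that keeps the original row stride:
//   struct ggml_tensor * blk = ggml_view_2d(ctx, m, m->ne[0], k, m->nb[1], 0);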
- - return result; -} - -// ggml_permute - -struct ggml_tensor * ggml_permute( - struct ggml_context * ctx, - struct ggml_tensor * a, - int axis0, - int axis1, - int axis2, - int axis3) { - assert(axis0 >= 0 && axis0 < GGML_MAX_DIMS); - assert(axis1 >= 0 && axis1 < GGML_MAX_DIMS); - assert(axis2 >= 0 && axis2 < GGML_MAX_DIMS); - assert(axis3 >= 0 && axis3 < GGML_MAX_DIMS); - - assert(axis0 != axis1); - assert(axis0 != axis2); - assert(axis0 != axis3); - assert(axis1 != axis2); - assert(axis1 != axis3); - assert(axis2 != axis3); - - bool is_node = false; - - if (a->grad) { - assert(false); // TODO: implement backward - is_node = true; - } - - struct ggml_tensor * result = ggml_view_tensor(ctx, a); - - int ne[GGML_MAX_DIMS]; - int nb[GGML_MAX_DIMS]; - - ne[axis0] = a->ne[0]; - ne[axis1] = a->ne[1]; - ne[axis2] = a->ne[2]; - ne[axis3] = a->ne[3]; - - nb[axis0] = a->nb[0]; - nb[axis1] = a->nb[1]; - nb[axis2] = a->nb[2]; - nb[axis3] = a->nb[3]; - - result->ne[0] = ne[0]; - result->ne[1] = ne[1]; - result->ne[2] = ne[2]; - result->ne[3] = ne[3]; - - result->nb[0] = nb[0]; - result->nb[1] = nb[1]; - result->nb[2] = nb[2]; - result->nb[3] = nb[3]; - - result->op = GGML_OP_PERMUTE; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src0 = a; - result->src1 = NULL; // TODO: maybe store the permutation here? - - return result; -} - -// ggml_transpose - -struct ggml_tensor * ggml_transpose( - struct ggml_context * ctx, - struct ggml_tensor * a) { - bool is_node = false; - - if (a->grad) { - assert(false); // TODO: implement backward - is_node = true; - } - - struct ggml_tensor * result = ggml_view_tensor(ctx, a); - - result->ne[0] = a->ne[1]; - result->ne[1] = a->ne[0]; - - result->nb[0] = a->nb[1]; - result->nb[1] = a->nb[0]; - - result->op = GGML_OP_TRANSPOSE; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src0 = a; - result->src1 = NULL; - - return result; -} - -// ggml_get_rows - -struct ggml_tensor * ggml_get_rows( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b) { - assert(ggml_is_matrix(a) && ggml_is_vector(b) && b->type == GGML_TYPE_I32); - - bool is_node = false; - - if (a->grad || b->grad) { - assert(false); // TODO: implement backward - is_node = true; - } - - // TODO: implement non F32 return - //struct ggml_tensor * result = ggml_new_tensor_2d(ctx, a->type, a->ne[0], b->ne[0]); - struct ggml_tensor * result = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, a->ne[0], b->ne[0]); - - result->op = GGML_OP_GET_ROWS; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src0 = a; - result->src1 = b; - - return result; -} - -// ggml_diag_mask_inf - -struct ggml_tensor * ggml_diag_mask_inf( - struct ggml_context * ctx, - struct ggml_tensor * a, - int n_past) { - bool is_node = false; - - if (a->grad) { - assert(false); // TODO: implement backward - is_node = true; - } - - // TODO: when implement backward, fix this: - //struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - struct ggml_tensor * result = ggml_view_tensor(ctx, a); - struct ggml_tensor * b = ggml_new_i32(ctx, n_past); - - result->op = GGML_OP_DIAG_MASK_INF; - result->grad = is_node ? 
ggml_dup_tensor(ctx, result) : NULL; - result->src0 = a; - result->src1 = b; - - return result; -} - -// ggml_soft_max - -struct ggml_tensor * ggml_soft_max( - struct ggml_context * ctx, - struct ggml_tensor * a) { - bool is_node = false; - - if (a->grad) { - assert(false); // TODO: implement backward - is_node = true; - } - - // TODO: when implement backward, fix this: - //struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - struct ggml_tensor * result = ggml_view_tensor(ctx, a); - - result->op = GGML_OP_SOFT_MAX; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src0 = a; - result->src1 = NULL; - - return result; -} - -// ggml_rope - -struct ggml_tensor * ggml_rope( - struct ggml_context * ctx, - struct ggml_tensor * a, - int n_past, - int n_dims, - int mode) { - assert(n_past >= 0); - bool is_node = false; - - if (a->grad) { - assert(false); // TODO: implement backward - is_node = true; - } - - // TODO: when implement backward, fix this: - //struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - struct ggml_tensor * result = ggml_view_tensor(ctx, a); - - struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 3); - ((int32_t *) b->data)[0] = n_past; - ((int32_t *) b->data)[1] = n_dims; - ((int32_t *) b->data)[2] = mode; - - result->op = GGML_OP_ROPE; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src0 = a; - result->src1 = b; - - return result; -} - -// ggml_conv_1d_1s - -struct ggml_tensor * ggml_conv_1d_1s( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b) { - assert(ggml_is_matrix(b)); - assert(a->ne[1] == b->ne[1]); - assert(a->ne[3] == 1); - bool is_node = false; - - if (a->grad || b->grad) { - assert(false); // TODO: implement backward - is_node = true; - } - - const int ne[4] = { b->ne[0], a->ne[2], 1, 1, }; - struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 2, ne); - - result->op = GGML_OP_CONV_1D_1S; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src0 = a; - result->src1 = b; - - return result; -} - -// ggml_conv_1d_2s - -struct ggml_tensor * ggml_conv_1d_2s( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b) { - assert(ggml_is_matrix(b)); - assert(a->ne[1] == b->ne[1]); - assert(a->ne[3] == 1); - bool is_node = false; - - if (a->grad || b->grad) { - assert(false); // TODO: implement backward - is_node = true; - } - - const int ne[4] = { b->ne[0]/2, a->ne[2], 1, 1, }; - struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 2, ne); - - result->op = GGML_OP_CONV_1D_2S; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src0 = a; - result->src1 = b; - - return result; -} - -// ggml_flash_attn - -struct ggml_tensor * ggml_flash_attn( - struct ggml_context * ctx, - struct ggml_tensor * q, - struct ggml_tensor * k, - struct ggml_tensor * v, - bool masked) { - assert(ggml_can_mul_mat(k, q)); - // TODO: check if vT can be multiplied by (k*qT) - - bool is_node = false; - - if (q->grad || k->grad || v->grad) { - GGML_ASSERT(false); // TODO: implement backward - is_node = true; - } - - //struct ggml_tensor * result = ggml_dup_tensor(ctx, q); - struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, q->ne); - - result->op = GGML_OP_FLASH_ATTN; - result->grad = is_node ? 
ggml_dup_tensor(ctx, result) : NULL; - result->src0 = q; - result->src1 = k; - result->opt[0] = v; - result->opt[1] = ggml_new_i32(ctx, masked ? 1 : 0); - - return result; -} - -// ggml_flash_ff - -struct ggml_tensor * ggml_flash_ff( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b0, - struct ggml_tensor * b1, - struct ggml_tensor * c0, - struct ggml_tensor * c1) { - assert(ggml_can_mul_mat(b0, a)); - // TODO: more checks - - bool is_node = false; - - if (a->grad || b0->grad || b1->grad || c0->grad || c1->grad) { - GGML_ASSERT(false); // TODO: implement backward - is_node = true; - } - - //struct ggml_tensor * result = ggml_dup_tensor(ctx, a); - struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, a->ne); - - result->op = GGML_OP_FLASH_FF; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src0 = a; - result->src1 = b0; - result->opt[0] = b1; - result->opt[1] = c0; - result->opt[2] = c1; - - return result; -} - -//////////////////////////////////////////////////////////////////////////////// - -void ggml_set_param( - struct ggml_context * ctx, - struct ggml_tensor * tensor) { - tensor->is_param = true; - - assert(tensor->grad == NULL); - tensor->grad = ggml_dup_tensor(ctx, tensor); -} - -// ggml_compute_forward_dup - -static void ggml_compute_forward_dup_f16( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - assert(params->ith == 0); - assert(ggml_is_contiguous(dst)); - assert(ggml_nelements(dst) == ggml_nelements(src0)); - - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - - const int ne00 = src0->ne[0]; - const int ne01 = src0->ne[1]; - const int ne02 = src0->ne[2]; - const int ne03 = src0->ne[3]; - - const size_t nb00 = src0->nb[0]; - const size_t nb01 = src0->nb[1]; - const size_t nb02 = src0->nb[2]; - const size_t nb03 = src0->nb[3]; - - if (ggml_is_contiguous(src0) && src0->type == dst->type) { - memcpy(dst->data, src0->data, ggml_nelements(dst) * GGML_TYPE_SIZE[src0->type]); - return; - } - - if (src0->nb[0] == sizeof(ggml_fp16_t)) { - if (dst->type == GGML_TYPE_F16) { - int id = 0; - const size_t rs = ne00*nb00; - - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - for (int i01 = 0; i01 < ne01; i01++) { - const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03; - char * dst_ptr = (char *) dst->data + id*rs; - - memcpy(dst_ptr, src0_ptr, rs); - - id++; - } - } - } - } else if (dst->type == GGML_TYPE_F32) { - int id = 0; - float * dst_ptr = (float *) dst->data; - - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - for (int i01 = 0; i01 < ne01; i01++) { - for (int i00 = 0; i00 < ne00; i00++) { - const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); - - dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr); - id++; - } - } - } - } - } else { - GGML_ASSERT(false); // TODO: implement - } - } else { - //printf("%s: this is not optimal - fix me\n", __func__); - - if (dst->type == GGML_TYPE_F32) { - int id = 0; - float * dst_ptr = (float *) dst->data; - - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - for (int i01 = 0; i01 < ne01; i01++) { - for (int i00 = 0; i00 < ne00; i00++) { - const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); - - dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr); - id++; - } 
- } - } - } - } else if (dst->type == GGML_TYPE_F16) { - int id = 0; - ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data; - - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - for (int i01 = 0; i01 < ne01; i01++) { - for (int i00 = 0; i00 < ne00; i00++) { - const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); - - dst_ptr[id] = *src0_ptr; - id++; - } - } - } - } - } else { - GGML_ASSERT(false); // TODO: implement - } - } -} - -static void ggml_compute_forward_dup_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - GGML_ASSERT(params->ith == 0); - GGML_ASSERT(ggml_is_contiguous(dst)); - GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0)); - - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - - const int ne00 = src0->ne[0]; - const int ne01 = src0->ne[1]; - const int ne02 = src0->ne[2]; - const int ne03 = src0->ne[3]; - - const size_t nb00 = src0->nb[0]; - const size_t nb01 = src0->nb[1]; - const size_t nb02 = src0->nb[2]; - const size_t nb03 = src0->nb[3]; - - if (ggml_is_contiguous(src0) && src0->type == dst->type) { - memcpy(dst->data, src0->data, ggml_nelements(dst) * GGML_TYPE_SIZE[src0->type]); - return; - } - - if (src0->nb[0] == sizeof(float)) { - if (dst->type == GGML_TYPE_F32) { - int id = 0; - const size_t rs = ne00*nb00; - - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - for (int i01 = 0; i01 < ne01; i01++) { - const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03; - char * dst_ptr = (char *) dst->data + id*rs; - - memcpy(dst_ptr, src0_ptr, rs); - - id++; - } - } - } - } else if (dst->type == GGML_TYPE_F16) { - int id = 0; - ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data; - - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - for (int i01 = 0; i01 < ne01; i01++) { - for (int i00 = 0; i00 < ne00; i00++) { - const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); - - dst_ptr[id] = GGML_FP32_TO_FP16(*src0_ptr); - id++; - } - } - } - } - } else { - GGML_ASSERT(false); // TODO: implement - } - } else { - //printf("%s: this is not optimal - fix me\n", __func__); - - if (dst->type == GGML_TYPE_F32) { - int id = 0; - float * dst_ptr = (float *) dst->data; - - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - for (int i01 = 0; i01 < ne01; i01++) { - for (int i00 = 0; i00 < ne00; i00++) { - const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); - - dst_ptr[id] = *src0_ptr; - id++; - } - } - } - } - } else if (dst->type == GGML_TYPE_F16) { - int id = 0; - ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data; - - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - for (int i01 = 0; i01 < ne01; i01++) { - for (int i00 = 0; i00 < ne00; i00++) { - const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); - - dst_ptr[id] = GGML_FP32_TO_FP16(*src0_ptr); - id++; - } - } - } - } - } else { - GGML_ASSERT(false); // TODO: implement - } - } -} - -static void ggml_compute_forward_dup( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - switch (src0->type) { - case GGML_TYPE_F16: - { - ggml_compute_forward_dup_f16(params, src0, dst); - } break; - case 
GGML_TYPE_F32: - { - ggml_compute_forward_dup_f32(params, src0, dst); - } break; - case GGML_TYPE_I8: - case GGML_TYPE_I16: - case GGML_TYPE_I32: - case GGML_TYPE_COUNT: - { - GGML_ASSERT(false); - } break; - } -} - -// ggml_compute_forward_add - -static void ggml_compute_forward_add_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst) { - GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst)); - - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - - const int ith = params->ith; - const int nth = params->nth; - - const int n = ggml_nrows(src0); - const int nc = src0->ne[0]; - - const size_t nb00 = src0->nb[0]; - const size_t nb01 = src0->nb[1]; - - const size_t nb10 = src1->nb[0]; - const size_t nb11 = src1->nb[1]; - - const size_t nb0 = dst->nb[0]; - const size_t nb1 = dst->nb[1]; - - GGML_ASSERT( nb0 == sizeof(float)); - GGML_ASSERT(nb00 == sizeof(float)); - - if (nb10 == sizeof(float)) { - const int j0 = (n/nth)*ith; - const int j1 = ith == nth - 1 ? n : (n/nth)*(ith + 1); - - for (int j = j0; j < j1; j++) { - ggml_vec_add_f32(nc, - (float *) ((char *) dst->data + j*nb1), - (float *) ((char *) src0->data + j*nb01), - (float *) ((char *) src1->data + j*nb11)); - } - } else { - // src1 is not contiguous - for (int j = ith; j < n; j += nth) { - float * dst_ptr = (float *) ((char *) dst->data + j*nb1); - float * src0_ptr = (float *) ((char *) src0->data + j*nb01); - for (int i = 0; i < nc; i++) { - float * src1_ptr = (float *) ((char *) src1->data + j*nb11 + i*nb10); - - dst_ptr[i] = src0_ptr[i] + *src1_ptr; - } - } - } -} - -static void ggml_compute_forward_add( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst) { - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_add_f32(params, src0, src1, dst); - } break; - case GGML_TYPE_I8: - case GGML_TYPE_I16: - case GGML_TYPE_I32: - case GGML_TYPE_F16: - case GGML_TYPE_COUNT: - { - assert(false); - } break; - } -} - -// ggml_compute_forward_sub - -static void ggml_compute_forward_sub_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst) { - assert(params->ith == 0); - assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst)); - - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - - const int n = ggml_nrows(src0); - const int nc = src0->ne[0]; - - assert( dst->nb[0] == sizeof(float)); - assert(src0->nb[0] == sizeof(float)); - assert(src1->nb[0] == sizeof(float)); - - for (int i = 0; i < n; i++) { - ggml_vec_sub_f32(nc, - (float *) ((char *) dst->data + i*( dst->nb[1])), - (float *) ((char *) src0->data + i*(src0->nb[1])), - (float *) ((char *) src1->data + i*(src1->nb[1]))); - } -} - -static void ggml_compute_forward_sub( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst) { - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_sub_f32(params, src0, src1, dst); - } break; - case GGML_TYPE_I8: - case GGML_TYPE_I16: - case GGML_TYPE_I32: - case GGML_TYPE_F16: - case GGML_TYPE_COUNT: - { - assert(false); - } break; - } -} - -// ggml_compute_forward_mul - -static void ggml_compute_forward_mul_f32( - const struct 
ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst) { - assert(params->ith == 0); - assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst)); - - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - - const int n = ggml_nrows(src0); - const int nc = src0->ne[0]; - - assert( dst->nb[0] == sizeof(float)); - assert(src0->nb[0] == sizeof(float)); - assert(src1->nb[0] == sizeof(float)); - - for (int i = 0; i < n; i++) { - ggml_vec_mul_f32(nc, - (float *) ((char *) dst->data + i*( dst->nb[1])), - (float *) ((char *) src0->data + i*(src0->nb[1])), - (float *) ((char *) src1->data + i*(src1->nb[1]))); - } -} - -static void ggml_compute_forward_mul( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst) { - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_mul_f32(params, src0, src1, dst); - } break; - case GGML_TYPE_I8: - case GGML_TYPE_I16: - case GGML_TYPE_I32: - case GGML_TYPE_F16: - case GGML_TYPE_COUNT: - { - assert(false); - } break; - } -} - -// ggml_compute_forward_div - -static void ggml_compute_forward_div_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst) { - assert(params->ith == 0); - assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst)); - - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - - const int n = ggml_nrows(src0); - const int nc = src0->ne[0]; - - assert( dst->nb[0] == sizeof(float)); - assert(src0->nb[0] == sizeof(float)); - assert(src1->nb[0] == sizeof(float)); - - for (int i = 0; i < n; i++) { - ggml_vec_div_f32(nc, - (float *) ((char *) dst->data + i*( dst->nb[1])), - (float *) ((char *) src0->data + i*(src0->nb[1])), - (float *) ((char *) src1->data + i*(src1->nb[1]))); - } -} - -static void ggml_compute_forward_div( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst) { - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_div_f32(params, src0, src1, dst); - } break; - case GGML_TYPE_I8: - case GGML_TYPE_I16: - case GGML_TYPE_I32: - case GGML_TYPE_F16: - case GGML_TYPE_COUNT: - { - assert(false); - } break; - } -} - -// ggml_compute_forward_sqr - -static void ggml_compute_forward_sqr_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - assert(params->ith == 0); - assert(ggml_are_same_shape(src0, dst)); - - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - - const int n = ggml_nrows(src0); - const int nc = src0->ne[0]; - - assert( dst->nb[0] == sizeof(float)); - assert(src0->nb[0] == sizeof(float)); - - for (int i = 0; i < n; i++) { - ggml_vec_sqr_f32(nc, - (float *) ((char *) dst->data + i*( dst->nb[1])), - (float *) ((char *) src0->data + i*(src0->nb[1]))); - } -} - -static void ggml_compute_forward_sqr( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_sqr_f32(params, src0, dst); - } break; - case GGML_TYPE_I8: - case GGML_TYPE_I16: - case GGML_TYPE_I32: - case GGML_TYPE_F16: - case GGML_TYPE_COUNT: - { - assert(false); - } break; - 
} -} - -// ggml_compute_forward_sqrt - -static void ggml_compute_forward_sqrt_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - assert(params->ith == 0); - assert(ggml_are_same_shape(src0, dst)); - - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - - const int n = ggml_nrows(src0); - const int nc = src0->ne[0]; - - assert( dst->nb[0] == sizeof(float)); - assert(src0->nb[0] == sizeof(float)); - - for (int i = 0; i < n; i++) { - ggml_vec_sqrt_f32(nc, - (float *) ((char *) dst->data + i*( dst->nb[1])), - (float *) ((char *) src0->data + i*(src0->nb[1]))); - } -} - -static void ggml_compute_forward_sqrt( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_sqrt_f32(params, src0, dst); - } break; - case GGML_TYPE_I8: - case GGML_TYPE_I16: - case GGML_TYPE_I32: - case GGML_TYPE_F16: - case GGML_TYPE_COUNT: - { - assert(false); - } break; - } -} - -// ggml_compute_forward_sum - -static void ggml_compute_forward_sum_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - assert(params->ith == 0); - assert(ggml_is_scalar(dst)); - - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - - assert(ggml_is_scalar(dst)); - assert(src0->nb[0] == sizeof(float)); - - const int ne00 = src0->ne[0]; - const int ne01 = src0->ne[1]; - const int ne02 = src0->ne[2]; - const int ne03 = src0->ne[3]; - - const size_t nb01 = src0->nb[1]; - const size_t nb02 = src0->nb[2]; - const size_t nb03 = src0->nb[3]; - - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - for (int i01 = 0; i01 < ne01; i01++) { - ggml_vec_sum_f32(ne00, - (float *) (dst->data), - (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03)); - } - } - } -} - -static void ggml_compute_forward_sum( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_sum_f32(params, src0, dst); - } break; - case GGML_TYPE_I8: - case GGML_TYPE_I16: - case GGML_TYPE_I32: - case GGML_TYPE_F16: - case GGML_TYPE_COUNT: - { - assert(false); - } break; - } -} - -// ggml_compute_forward_mean - -static void ggml_compute_forward_mean_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - assert(params->ith == 0); - - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - - assert(src0->nb[0] == sizeof(float)); - - const int ne00 = src0->ne[0]; - const int ne01 = src0->ne[1]; - const int ne02 = src0->ne[2]; - const int ne03 = src0->ne[3]; - - const size_t nb01 = src0->nb[1]; - const size_t nb02 = src0->nb[2]; - const size_t nb03 = src0->nb[3]; - - const int ne0 = dst->ne[0]; - const int ne1 = dst->ne[1]; - const int ne2 = dst->ne[2]; - const int ne3 = dst->ne[3]; - - assert(ne0 == 1); - assert(ne1 == ne01); - assert(ne2 == ne02); - assert(ne3 == ne03); - - UNUSED(ne0); - UNUSED(ne1); - UNUSED(ne2); - UNUSED(ne3); - - const size_t nb1 = dst->nb[1]; - const size_t nb2 = dst->nb[2]; - const size_t nb3 = dst->nb[3]; - - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - for (int i01 = 0; i01 < ne01; i01++) { - ggml_vec_sum_f32(ne00, - (float *) ((char *) 
dst->data + i01*nb1 + i02*nb2 + i03*nb3), - (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03)); - - *(float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3) /= (float) ne00; - } - } - } -} - -static void ggml_compute_forward_mean( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_mean_f32(params, src0, dst); - } break; - case GGML_TYPE_I8: - case GGML_TYPE_I16: - case GGML_TYPE_I32: - case GGML_TYPE_F16: - case GGML_TYPE_COUNT: - { - assert(false); - } break; - } -} - -// ggml_compute_forward_repeat - -static void ggml_compute_forward_repeat_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - assert(params->ith == 0); - assert(ggml_can_repeat(src0, dst)); - - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - - // TODO: implement support for rank > 2 tensors - assert(src0->ne[2] == 1); - assert(src0->ne[3] == 1); - assert( dst->ne[2] == 1); - assert( dst->ne[3] == 1); - - const int nc = dst->ne[0]; - const int nr = dst->ne[1]; - const int nc0 = src0->ne[0]; - const int nr0 = src0->ne[1]; - const int ncr = nc/nc0; // guaranteed to be an integer due to the check in ggml_can_repeat - const int nrr = nr/nr0; // guaranteed to be an integer due to the check in ggml_can_repeat - - // TODO: support for transposed / permuted tensors - assert( dst->nb[0] == sizeof(float)); - assert(src0->nb[0] == sizeof(float)); - - // TODO: maybe this is not optimal? - for (int i = 0; i < nrr; i++) { - for (int j = 0; j < ncr; j++) { - for (int k = 0; k < nr0; k++) { - ggml_vec_cpy_f32(nc0, - (float *) ((char *) dst->data + (i*nr0 + k)*( dst->nb[1]) + j*nc0*( dst->nb[0])), - (float *) ((char *) src0->data + ( k)*(src0->nb[1]))); - } - } - } -} - -static void ggml_compute_forward_repeat( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_repeat_f32(params, src0, dst); - } break; - case GGML_TYPE_I8: - case GGML_TYPE_I16: - case GGML_TYPE_I32: - case GGML_TYPE_F16: - case GGML_TYPE_COUNT: - { - assert(false); - } break; - } -} - -// ggml_compute_forward_abs - -static void ggml_compute_forward_abs_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - assert(params->ith == 0); - assert(ggml_are_same_shape(src0, dst)); - - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - - const int n = ggml_nrows(src0); - const int nc = src0->ne[0]; - - assert(dst->nb[0] == sizeof(float)); - assert(src0->nb[0] == sizeof(float)); - - for (int i = 0; i < n; i++) { - ggml_vec_abs_f32(nc, - (float *) ((char *) dst->data + i*( dst->nb[1])), - (float *) ((char *) src0->data + i*(src0->nb[1]))); - } -} - -static void ggml_compute_forward_abs( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_abs_f32(params, src0, dst); - } break; - case GGML_TYPE_I8: - case GGML_TYPE_I16: - case GGML_TYPE_I32: - case GGML_TYPE_F16: - case GGML_TYPE_COUNT: - { - assert(false); - } break; - } -} - -// ggml_compute_forward_sgn - -static void ggml_compute_forward_sgn_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * 
src0, - struct ggml_tensor * dst) { - assert(params->ith == 0); - assert(ggml_are_same_shape(src0, dst)); - - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - - const int n = ggml_nrows(src0); - const int nc = src0->ne[0]; - - assert(dst->nb[0] == sizeof(float)); - assert(src0->nb[0] == sizeof(float)); - - for (int i = 0; i < n; i++) { - ggml_vec_sgn_f32(nc, - (float *) ((char *) dst->data + i*( dst->nb[1])), - (float *) ((char *) src0->data + i*(src0->nb[1]))); - } -} - -static void ggml_compute_forward_sgn( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_sgn_f32(params, src0, dst); - } break; - case GGML_TYPE_I8: - case GGML_TYPE_I16: - case GGML_TYPE_I32: - case GGML_TYPE_F16: - case GGML_TYPE_COUNT: - { - assert(false); - } break; - } -} - -// ggml_compute_forward_neg - -static void ggml_compute_forward_neg_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - assert(params->ith == 0); - assert(ggml_are_same_shape(src0, dst)); - - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - - const int n = ggml_nrows(src0); - const int nc = src0->ne[0]; - - assert(dst->nb[0] == sizeof(float)); - assert(src0->nb[0] == sizeof(float)); - - for (int i = 0; i < n; i++) { - ggml_vec_neg_f32(nc, - (float *) ((char *) dst->data + i*( dst->nb[1])), - (float *) ((char *) src0->data + i*(src0->nb[1]))); - } -} - -static void ggml_compute_forward_neg( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_neg_f32(params, src0, dst); - } break; - case GGML_TYPE_I8: - case GGML_TYPE_I16: - case GGML_TYPE_I32: - case GGML_TYPE_F16: - case GGML_TYPE_COUNT: - { - assert(false); - } break; - } -} - -// ggml_compute_forward_step - -static void ggml_compute_forward_step_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - assert(params->ith == 0); - assert(ggml_are_same_shape(src0, dst)); - - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - - const int n = ggml_nrows(src0); - const int nc = src0->ne[0]; - - assert(dst->nb[0] == sizeof(float)); - assert(src0->nb[0] == sizeof(float)); - - for (int i = 0; i < n; i++) { - ggml_vec_step_f32(nc, - (float *) ((char *) dst->data + i*( dst->nb[1])), - (float *) ((char *) src0->data + i*(src0->nb[1]))); - } -} - -static void ggml_compute_forward_step( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_step_f32(params, src0, dst); - } break; - case GGML_TYPE_I8: - case GGML_TYPE_I16: - case GGML_TYPE_I32: - case GGML_TYPE_F16: - case GGML_TYPE_COUNT: - { - assert(false); - } break; - } -} - -// ggml_compute_forward_relu - -static void ggml_compute_forward_relu_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - assert(params->ith == 0); - assert(ggml_are_same_shape(src0, dst)); - - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - - const int n = ggml_nrows(src0); - const int nc = src0->ne[0]; - - assert(dst->nb[0] == sizeof(float)); 
- assert(src0->nb[0] == sizeof(float)); - - for (int i = 0; i < n; i++) { - ggml_vec_relu_f32(nc, - (float *) ((char *) dst->data + i*( dst->nb[1])), - (float *) ((char *) src0->data + i*(src0->nb[1]))); - } -} - -static void ggml_compute_forward_relu( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_relu_f32(params, src0, dst); - } break; - case GGML_TYPE_I8: - case GGML_TYPE_I16: - case GGML_TYPE_I32: - case GGML_TYPE_F16: - case GGML_TYPE_COUNT: - { - assert(false); - } break; - } -} - -// ggml_compute_forward_gelu - -static void ggml_compute_forward_gelu_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - GGML_ASSERT(ggml_is_contiguous(src0)); - GGML_ASSERT(ggml_is_contiguous(dst)); - GGML_ASSERT(ggml_are_same_shape(src0, dst)); - - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - - const int ith = params->ith; - const int nth = params->nth; - - const int nc = src0->ne[0]; - const int nr = ggml_nrows(src0); - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - for (int i1 = ir0; i1 < ir1; i1++) { - ggml_vec_gelu_f32(nc, - (float *) ((char *) dst->data + i1*( dst->nb[1])), - (float *) ((char *) src0->data + i1*(src0->nb[1]))); - -#ifndef NDEBUG - for (int k = 0; k < nc; k++) { - const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k]; - UNUSED(x); - assert(!isnan(x)); - assert(!isinf(x)); - } -#endif - } -} - -static void ggml_compute_forward_gelu( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_gelu_f32(params, src0, dst); - } break; - case GGML_TYPE_I8: - case GGML_TYPE_I16: - case GGML_TYPE_I32: - case GGML_TYPE_F16: - case GGML_TYPE_COUNT: - { - assert(false); - } break; - } -} - -// ggml_compute_forward_norm - -static void ggml_compute_forward_norm_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - GGML_ASSERT(ggml_are_same_shape(src0, dst)); - - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - - GGML_ASSERT(src0->nb[0] == sizeof(float)); - - const int ith = params->ith; - const int nth = params->nth; - - const int ne00 = src0->ne[0]; - const int ne01 = src0->ne[1]; - const int ne02 = src0->ne[2]; - const int ne03 = src0->ne[3]; - - const size_t nb01 = src0->nb[1]; - const size_t nb02 = src0->nb[2]; - const size_t nb03 = src0->nb[3]; - - const size_t nb1 = dst->nb[1]; - const size_t nb2 = dst->nb[2]; - const size_t nb3 = dst->nb[3]; - - const ggml_float eps = 1e-5f; // TODO: make this a parameter - - // TODO: optimize - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - for (int i01 = ith; i01 < ne01; i01 += nth) { - const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); - - ggml_float mean = 0.0; - for (int i00 = 0; i00 < ne00; i00++) { - mean += x[i00]; - } - - mean /= ne00; - - float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3); - - ggml_float sum2 = 0.0; - for (int i00 = 0; i00 < ne00; i00++) { - ggml_float v = x[i00] - mean; - y[i00] = v; - sum2 += v*v; - } - - const float scale = 1.0/sqrt(sum2/ne00 + 
eps); - - ggml_vec_scale_f32(ne00, y, scale); - } - } - } -} - -static void ggml_compute_forward_norm( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_norm_f32(params, src0, dst); - } break; - case GGML_TYPE_I8: - case GGML_TYPE_I16: - case GGML_TYPE_I32: - case GGML_TYPE_F16: - case GGML_TYPE_COUNT: - { - assert(false); - } break; - } -} - -// ggml_compute_forward_mul_mat - -#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) -// helper function to determine if it is better to use BLAS or not -// for large matrices, BLAS is faster -static bool ggml_compute_forward_mul_mat_use_blas( - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst) { - UNUSED(src0); - - const int ne10 = src1->ne[0]; - - const int ne0 = dst->ne[0]; - const int ne1 = dst->ne[1]; - - // TODO: find the optimal values for these - if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ( - (ne0 >= 32 && ne1 >= 32 && ne10 >= 32) - )) { - //printf("BLAS: %d %d %d\n", ne0, ne1, ne10); - return true; - } - - return false; -} -#endif - -static void ggml_compute_forward_mul_mat_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst) { - int64_t t0 = ggml_perf_time_us(); - UNUSED(t0); - - const int ne00 = src0->ne[0]; - const int ne01 = src0->ne[1]; - const int ne02 = src0->ne[2]; - const int ne03 = src0->ne[3]; - - const int ne10 = src1->ne[0]; - const int ne11 = src1->ne[1]; - const int ne12 = src1->ne[2]; - const int ne13 = src1->ne[3]; - - const int ne0 = dst->ne[0]; - const int ne1 = dst->ne[1]; - const int ne2 = dst->ne[2]; - const int ne3 = dst->ne[3]; - const int ne = ne0*ne1*ne2*ne3; - - const int nb00 = src0->nb[0]; - const int nb01 = src0->nb[1]; - const int nb02 = src0->nb[2]; - const int nb03 = src0->nb[3]; - - const int nb10 = src1->nb[0]; - const int nb11 = src1->nb[1]; - const int nb12 = src1->nb[2]; - const int nb13 = src1->nb[3]; - - const int nb0 = dst->nb[0]; - const int nb1 = dst->nb[1]; - const int nb2 = dst->nb[2]; - const int nb3 = dst->nb[3]; - - const int ith = params->ith; - const int nth = params->nth; - - assert(ne02 == ne12); - assert(ne03 == ne13); - assert(ne2 == ne12); - assert(ne3 == ne13); - - // TODO: we don't support permuted src0 - assert(nb00 == sizeof(float) || nb01 == sizeof(float)); - - // dst cannot be transposed or permuted - assert(nb0 == sizeof(float)); - assert(nb0 <= nb1); - assert(nb1 <= nb2); - assert(nb2 <= nb3); - - assert(ne0 == ne01); - assert(ne1 == ne11); - assert(ne2 == ne02); - assert(ne3 == ne03); - - // nb01 >= nb00 - src0 is not transposed - // compute by src0 rows - // - // nb00 < nb01 - src0 is transposed - // compute by src0 columns - -#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) - if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) { - GGML_ASSERT(nb10 == sizeof(float)); - - if (params->ith != 0) { - return; - } - - if (params->type == GGML_TASK_INIT) { - return; - } - - if (params->type == GGML_TASK_FINALIZE) { - return; - } - - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - const float * x = (float *) (src0->data); - const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13); - - float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3); - - // zT = y * xT - { - cblas_sgemm(CblasRowMajor, CblasNoTrans, 
CblasTrans, - ne11, ne01, ne10, - 1.0f, y, ne10, - x, ne10, - 0.0f, d, ne01); - } - } - } - - //printf("CBLAS F32 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3); - - return; - } -#endif - - if (params->type == GGML_TASK_INIT) { - if (nb01 >= nb00) { - return; - } - - // TODO: fix this memset (wsize is overestimated) - memset(params->wdata, 0, params->wsize); - return; - } - - if (params->type == GGML_TASK_FINALIZE) { - if (nb01 >= nb00) { - return; - } - - // TODO: fix this memset (wsize is overestimated) - //assert(params->wsize == (ggml_nbytes(dst) + CACHE_LINE_SIZE)*nth); - - float * const wdata = params->wdata; - - // cols per thread - const int dc = (ne + nth - 1)/nth; - - // col range for this thread - const int ic0 = dc*ith; - const int ic1 = MIN(ic0 + dc, ne); - - ggml_vec_cpy_f32(ic1 - ic0, (float *) dst->data + ic0, wdata + ic0); - - for (int k = 1; k < nth; k++) { - ggml_vec_acc_f32(ic1 - ic0, (float *) dst->data + ic0, wdata + (ne + CACHE_LINE_SIZE_F32)*k + ic0); - } - - return; - } - - if (nb01 >= nb00) { - // TODO: do not support transposed src1 - assert(nb10 == sizeof(float)); - - // parallelize by src0 rows using ggml_vec_dot_f32 - - // total rows in src0 - const int nr = ne01*ne02*ne03; - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - for (int ir = ir0; ir < ir1; ++ir) { - // src0 indices - const int i03 = ir/(ne02*ne01); - const int i02 = (ir - i03*ne02*ne01)/ne01; - const int i01 = (ir - i03*ne02*ne01 - i02*ne01); - - for (int ic = 0; ic < ne11; ++ic) { - // src1 indices - const int i13 = i03; - const int i12 = i02; - const int i11 = ic; - - // dst indices - const int i0 = i01; - const int i1 = i11; - const int i2 = i02; - const int i3 = i03; - - ggml_vec_dot_f32(ne00, - (float *) ((char *) dst->data + (i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), - (float *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03)), - (float *) ((char *) src1->data + (i11*nb11 + i12*nb12 + i13*nb13))); - } - } - } else { - // parallelize by src1 columns using ggml_vec_mad_f32 - // each thread has its own work data - // during FINALIZE we accumulate all work data into dst - - // total columns in src1 - const int nc = ne10; - - // columns per thread - const int dc = (nc + nth - 1)/nth; - - // column range for this thread - const int ic0 = dc*ith; - const int ic1 = MIN(ic0 + dc, nc); - - // work data for thread - const int wo = (ne + CACHE_LINE_SIZE_F32)*ith; - float * const wdata = params->wdata; - - for (int i13 = 0; i13 < ne13; ++i13) { - for (int i12 = 0; i12 < ne12; ++i12) { - for (int i11 = 0; i11 < ne11; ++i11) { - for (int ic = ic0; ic < ic1; ++ic) { - // src1 indices - const int i10 = ic; - - // src0 indices - const int i03 = i13; - const int i02 = i12; - const int i00 = ic; - - // dst indices - const int i1 = i11; - const int i2 = i12; - const int i3 = i13; - - assert(sizeof(float)*(wo + i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0 + ne01) <= params->wsize); - - ggml_vec_mad_f32(ne01, - (float *) (wdata + wo + i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0), - (float *) ((char *) src0->data + (i00*nb00 + i02*nb02 + i03*nb03)), - *(float *) ((char *) src1->data + (i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13))); - } - } - } - } - } - - //int64_t t1 = ggml_perf_time_us(); - //static int64_t acc = 0; - //acc += t1 - t0; - //if (t1 - t0 > 10) { - // printf("\n"); - // printf("ne00 = %5d, ne01 = %5d, ne02 = %5d, ne03 = %5d\n", ne00, ne01, ne02, ne03); - // 
printf("nb00 = %5d, nb01 = %5d, nb02 = %5d, nb03 = %5d\n", nb00, nb01, nb02, nb03); - // printf("ne10 = %5d, ne11 = %5d, ne12 = %5d, ne13 = %5d\n", ne10, ne11, ne12, ne13); - // printf("nb10 = %5d, nb11 = %5d, nb12 = %5d, nb13 = %5d\n", nb10, nb11, nb12, nb13); - - // printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX task %d/%d: %d us, acc = %d\n", ith, nth, (int) (t1 - t0), (int) acc); - //} -} - -static void ggml_compute_forward_mul_mat_f16_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst) { - int64_t t0 = ggml_perf_time_us(); - UNUSED(t0); - - const int ne00 = src0->ne[0]; - const int ne01 = src0->ne[1]; - const int ne02 = src0->ne[2]; - const int ne03 = src0->ne[3]; - - const int ne10 = src1->ne[0]; - const int ne11 = src1->ne[1]; - const int ne12 = src1->ne[2]; - const int ne13 = src1->ne[3]; - - const int ne0 = dst->ne[0]; - const int ne1 = dst->ne[1]; - const int ne2 = dst->ne[2]; - const int ne3 = dst->ne[3]; - const int ne = ne0*ne1*ne2*ne3; - - const int nb00 = src0->nb[0]; - const int nb01 = src0->nb[1]; - const int nb02 = src0->nb[2]; - const int nb03 = src0->nb[3]; - - const int nb10 = src1->nb[0]; - const int nb11 = src1->nb[1]; - const int nb12 = src1->nb[2]; - const int nb13 = src1->nb[3]; - - const int nb0 = dst->nb[0]; - const int nb1 = dst->nb[1]; - const int nb2 = dst->nb[2]; - const int nb3 = dst->nb[3]; - - const int ith = params->ith; - const int nth = params->nth; - - GGML_ASSERT(ne02 == ne12); - GGML_ASSERT(ne03 == ne13); - GGML_ASSERT(ne2 == ne12); - GGML_ASSERT(ne3 == ne13); - - // TODO: we don't support permuted src0 - GGML_ASSERT(nb00 == sizeof(ggml_fp16_t) || nb01 == sizeof(ggml_fp16_t)); - - // dst cannot be transposed or permuted - GGML_ASSERT(nb0 == sizeof(float)); - GGML_ASSERT(nb0 <= nb1); - GGML_ASSERT(nb1 <= nb2); - GGML_ASSERT(nb2 <= nb3); - - GGML_ASSERT(ne0 == ne01); - GGML_ASSERT(ne1 == ne11); - GGML_ASSERT(ne2 == ne02); - GGML_ASSERT(ne3 == ne03); - - // nb01 >= nb00 - src0 is not transposed - // compute by src0 rows - // - // nb00 < nb01 - src0 is transposed - // compute by src0 columns - -#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) - if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) { - GGML_ASSERT(nb10 == sizeof(float)); - - if (params->ith != 0) { - return; - } - - if (params->type == GGML_TASK_INIT) { - return; - } - - if (params->type == GGML_TASK_FINALIZE) { - return; - } - - float * const wdata = params->wdata; - - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - { - int id = 0; - for (int i01 = 0; i01 < ne01; ++i01) { - for (int i00 = 0; i00 < ne00; ++i00) { - wdata[id++] = GGML_FP16_TO_FP32(*(ggml_fp16_t *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00)); - } - } - } - - const float * x = wdata; - const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13); - - // float * z = wdata + ne00*ne01; - - // z = x * yT - //{ - // cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, - // ne01, ne11, ne00, - // 1.0f, x, ne00, - // y, ne00, - // 0.0f, z, ne11); - //} - - float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3); - - // transpose z - //for (int j = 0; j < ne11; ++j) { - // for (int i = 0; i < ne01; ++i) { - // d[j*ne01 + i] = z[i*ne11 + j]; - // } - //} - - { -#if 1 - // zT = y * xT - cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, - ne11, ne01, ne10, - 1.0f, y, ne00, - x, ne00, - 0.0f, d, ne01); -#else - // zT = (xT * y)T - 
cblas_sgemm(CblasColMajor, CblasTrans, CblasNoTrans, - ne01, ne11, ne10, - 1.0f, x, ne00, - y, ne00, - 0.0f, d, ne01); -#endif - } - } - } - - //printf("CBLAS = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3); - - return; - } -#endif - - if (params->type == GGML_TASK_INIT) { - if (nb01 >= nb00) { - ggml_fp16_t * const wdata = params->wdata; - - int id = 0; - for (int i13 = 0; i13 < ne13; ++i13) { - for (int i12 = 0; i12 < ne12; ++i12) { - for (int i11 = 0; i11 < ne11; ++i11) { - for (int i10 = 0; i10 < ne10; ++i10) { - wdata[id++] = GGML_FP32_TO_FP16(*(float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10)); - } - } - } - } - - GGML_ASSERT(id*sizeof(ggml_fp16_t) <= params->wsize); - - return; - } - - // TODO: fix this memset (wsize is overestimated) - memset(params->wdata, 0, params->wsize); - return; - } - - if (params->type == GGML_TASK_FINALIZE) { - if (nb01 >= nb00) { - return; - } - - // TODO: fix this memset (wsize is overestimated) - //assert(params->wsize == (ggml_nbytes(dst) + CACHE_LINE_SIZE)*nth); - - ggml_fp16_t * const wdata = params->wdata; - - // cols per thread - const int dc = (ne + nth - 1)/nth; - - // col range for this thread - const int ic0 = dc*ith; - const int ic1 = MIN(ic0 + dc, ne); - - for (int i = ic0; i < ic1; ++i) { - ((float *) dst->data)[i] = GGML_FP16_TO_FP32(wdata[i]); - } - - for (int k = 1; k < nth; k++) { - for (int i = ic0; i < ic1; ++i) { - ((float *) dst->data)[i] += GGML_FP16_TO_FP32(wdata[(ne + CACHE_LINE_SIZE_F32)*k + i]); - } - } - - return; - } - - if (nb01 >= nb00) { - // fp16 -> half the size, so divide by 2 - // TODO: do not support transposed src1 - assert(nb10/2 == sizeof(ggml_fp16_t)); - - // parallelize by src0 rows using ggml_vec_dot_f16 - - // total rows in src0 - const int nr = ne01*ne02*ne03; - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - ggml_fp16_t * wdata = params->wdata; - - for (int ir = ir0; ir < ir1; ++ir) { - // src0 indices - const int i03 = ir/(ne02*ne01); - const int i02 = (ir - i03*ne02*ne01)/ne01; - const int i01 = (ir - i03*ne02*ne01 - i02*ne01); - - const int i13 = i03; - const int i12 = i02; - - const int i0 = i01; - const int i2 = i02; - const int i3 = i03; - - ggml_fp16_t * src0_row = (ggml_fp16_t *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03)); - ggml_fp16_t * src1_col = wdata + ( 0 + i12*ne11 + i13*ne12*ne11)*ne00; - - float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3)); - - assert(ne00 % 32 == 0); - - for (int ic = 0; ic < ne11; ++ic) { - ggml_vec_dot_f16(ne00, &dst_col[ic*ne0], src0_row, src1_col + ic*ne00); - } - } - } else { - // parallelize by src1 columns using ggml_vec_mad_f16 - // each thread has its own work data - // during FINALIZE we accumulate all work data into dst - - // total columns in src1 - const int nc = ne10; - - // columns per thread - const int dc = (nc + nth - 1)/nth; - - // column range for this thread - const int ic0 = dc*ith; - const int ic1 = MIN(ic0 + dc, nc); - - // work data for thread - const int wo = (ne + CACHE_LINE_SIZE_F32)*ith; - ggml_fp16_t * const wdata = params->wdata; - - for (int i13 = 0; i13 < ne13; ++i13) { - for (int i12 = 0; i12 < ne12; ++i12) { - for (int i11 = 0; i11 < ne11; ++i11) { - // dst indices - const int i1 = i11; - const int i2 = i12; - const int i3 = i13; - - ggml_fp16_t * dst_row = wdata + wo + i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0; - - for (int ic 
= ic0; ic < ic1; ++ic) { - // src1 indices - const int i10 = ic; - - // src0 indices - const int i03 = i13; - const int i02 = i12; - const int i00 = ic; - - assert(sizeof(ggml_fp16_t)*(wo + i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0 + ne01) <= params->wsize); - - ggml_fp16_t * src0_col = (ggml_fp16_t *) ((char *) src0->data + (i00*nb00 + i02*nb02 + i03*nb03)); - float src1_val = * (float *) ((char *) src1->data + (i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13)); - - ggml_vec_mad_f16(ne01, dst_row, src0_col, src1_val); - } - } - } - } - } - - //int64_t t1 = ggml_time_us(); - //static int64_t acc = 0; - //acc += t1 - t0; - //if (t1 - t0 > 10) { - // printf("\n"); - // printf("ne00 = %5d, ne01 = %5d, ne02 = %5d, ne03 = %5d\n", ne00, ne01, ne02, ne03); - // printf("nb00 = %5d, nb01 = %5d, nb02 = %5d, nb03 = %5d\n", nb00, nb01, nb02, nb03); - // printf("ne10 = %5d, ne11 = %5d, ne12 = %5d, ne13 = %5d\n", ne10, ne11, ne12, ne13); - - // printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX task %d/%d: %d us, acc = %d\n", ith, nth, (int) (t1 - t0), (int) acc); - //} -} - -static void ggml_compute_forward_mul_mat( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst) { - switch (src0->type) { - case GGML_TYPE_F16: - { - ggml_compute_forward_mul_mat_f16_f32(params, src0, src1, dst); - } break; - case GGML_TYPE_F32: - { - ggml_compute_forward_mul_mat_f32(params, src0, src1, dst); - } break; - case GGML_TYPE_I8: - case GGML_TYPE_I16: - case GGML_TYPE_I32: - case GGML_TYPE_COUNT: - { - assert(false); - } break; - } -} - -// ggml_compute_forward_scale - -static void ggml_compute_forward_scale_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst) { - GGML_ASSERT(ggml_is_contiguous(src0)); - GGML_ASSERT(ggml_is_contiguous(dst)); - GGML_ASSERT(ggml_are_same_shape(src0, dst)); - GGML_ASSERT(ggml_is_scalar(src1)); - - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - - // scale factor - const float v = *(float *) src1->data; - - const int ith = params->ith; - const int nth = params->nth; - - const int nc = src0->ne[0]; - const int nr = ggml_nrows(src0); - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - for (int i1 = ir0; i1 < ir1; i1++) { - ggml_vec_scale_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), v); - } -} - -static void ggml_compute_forward_scale( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst) { - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_scale_f32(params, src0, src1, dst); - } break; - case GGML_TYPE_I8: - case GGML_TYPE_I16: - case GGML_TYPE_I32: - case GGML_TYPE_F16: - case GGML_TYPE_COUNT: - { - assert(false); - } break; - } -} - -// ggml_compute_forward_cpy - -static void ggml_compute_forward_cpy( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - ggml_compute_forward_dup(params, src0, dst); -} - -// ggml_compute_forward_reshape - -static void ggml_compute_forward_reshape( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - // NOP - UNUSED(params); - UNUSED(src0); - UNUSED(dst); -} - -// ggml_compute_forward_view - 
-static void ggml_compute_forward_view( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0) { - // NOP - UNUSED(params); - UNUSED(src0); -} - -// ggml_compute_forward_permute - -static void ggml_compute_forward_permute( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0) { - // NOP - UNUSED(params); - UNUSED(src0); -} - -// ggml_compute_forward_transpose - -static void ggml_compute_forward_transpose( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0) { - // NOP - UNUSED(params); - UNUSED(src0); -} - -// ggml_compute_forward_get_rows - -static void ggml_compute_forward_get_rows_f16( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst) { - assert(params->ith == 0); - - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - - const int nc = src0->ne[0]; - const int nr = ggml_nelements(src1); - - assert( dst->ne[0] == nc); - assert( dst->ne[1] == nr); - assert(src0->nb[0] == sizeof(ggml_fp16_t)); - - for (int i = 0; i < nr; ++i) { - const int r = ((int32_t *) src1->data)[i]; - - for (int j = 0; j < nc; ++j) { - ggml_fp16_t v = ((ggml_fp16_t *) ((char *) src0->data + r*src0->nb[1]))[j]; - ((float *) ((char *) dst->data + i*dst->nb[1]))[j] = GGML_FP16_TO_FP32(v); - } - } -} - -static void ggml_compute_forward_get_rows_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst) { - assert(params->ith == 0); - - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - - const int nc = src0->ne[0]; - const int nr = ggml_nelements(src1); - - assert( dst->ne[0] == nc); - assert( dst->ne[1] == nr); - assert(src0->nb[0] == sizeof(float)); - - for (int i = 0; i < nr; ++i) { - const int r = ((int32_t *) src1->data)[i]; - - ggml_vec_cpy_f32(nc, - (float *) ((char *) dst->data + i*dst->nb[1]), - (float *) ((char *) src0->data + r*src0->nb[1])); - } -} - -static void ggml_compute_forward_get_rows( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst) { - switch (src0->type) { - case GGML_TYPE_F16: - { - ggml_compute_forward_get_rows_f16(params, src0, src1, dst); - } break; - case GGML_TYPE_F32: - { - ggml_compute_forward_get_rows_f32(params, src0, src1, dst); - } break; - case GGML_TYPE_I8: - case GGML_TYPE_I16: - case GGML_TYPE_I32: - case GGML_TYPE_COUNT: - { - assert(false); - } break; - } -} - -// ggml_compute_forward_diag_mask_inf - -static void ggml_compute_forward_diag_mask_inf_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst) { - assert(params->ith == 0); - assert(src1->type == GGML_TYPE_I32); - assert(ggml_nelements(src1) == 1); - - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - - const int n_past = ((int32_t *) src1->data)[0]; - - // TODO: handle transposed/permuted matrices - - const int n = ggml_nrows(src0); - const int nc = src0->ne[0]; - const int nr = src0->ne[1]; - const int nz = n/nr; - - assert( dst->nb[0] == sizeof(float)); - assert(src0->nb[0] == sizeof(float)); - - for (int k = 0; k < nz; k++) { - for (int j = 0; j < nr; j++) { - for (int i = n_past; i < nc; i++) { - if (i > n_past + j) { - *(float *)((char *) dst->data + 
k*dst->nb[2] + j*dst->nb[1] + i*dst->nb[0]) = -INFINITY; - } - } - } - } -} - -static void ggml_compute_forward_diag_mask_inf( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst) { - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_diag_mask_inf_f32(params, src0, src1, dst); - } break; - case GGML_TYPE_I8: - case GGML_TYPE_I16: - case GGML_TYPE_I32: - case GGML_TYPE_F16: - case GGML_TYPE_COUNT: - { - assert(false); - } break; - } -} - -// ggml_compute_forward_soft_max - -static void ggml_compute_forward_soft_max_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - GGML_ASSERT(ggml_is_contiguous(src0)); - GGML_ASSERT(ggml_is_contiguous(dst)); - GGML_ASSERT(ggml_are_same_shape(src0, dst)); - - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - - // TODO: handle transposed/permuted matrices - - const int ith = params->ith; - const int nth = params->nth; - - const int nc = src0->ne[0]; - const int nr = ggml_nrows(src0); - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - for (int i1 = ir0; i1 < ir1; i1++) { - float *p = (float *)((char *) dst->data + i1*dst->nb[1]); - -#ifndef NDEBUG - for (int i = 0; i < nc; ++i) { - assert(!isnan(p[i])); - } -#endif - - float max = -INFINITY; - ggml_vec_max_f32(nc, &max, p); - - ggml_float sum = 0.0; - - uint16_t scvt; - for (int i = 0; i < nc; i++) { - if (p[i] == -INFINITY) { - p[i] = 0.0f; - } else { - //const float val = (p[i] == -INFINITY) ? 0.0 : exp(p[i] - max); - ggml_fp16_t s = GGML_FP32_TO_FP16(p[i] - max); - memcpy(&scvt, &s, sizeof(scvt)); - const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt]); - sum += val; - p[i] = val; - } - } - - assert(sum > 0.0f); - - sum = 1.0/sum; - ggml_vec_scale_f32(nc, p, sum); - -#ifndef NDEBUG - for (int i = 0; i < nc; ++i) { - assert(!isnan(p[i])); - assert(!isinf(p[i])); - } -#endif - } -} - -static void ggml_compute_forward_soft_max( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_soft_max_f32(params, src0, dst); - } break; - case GGML_TYPE_I8: - case GGML_TYPE_I16: - case GGML_TYPE_I32: - case GGML_TYPE_F16: - case GGML_TYPE_COUNT: - { - assert(false); - } break; - } -} - -// ggml_compute_forward_rope - -static void ggml_compute_forward_rope_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst) { - assert(params->ith == 0); - assert(src1->type == GGML_TYPE_I32); - assert(ggml_nelements(src1) == 3); - - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - - const int n_past = ((int32_t *) src1->data)[0]; - const int n_dims = ((int32_t *) src1->data)[1]; - const int mode = ((int32_t *) src1->data)[2]; - - //const int ne0 = src0->ne[0]; - const int ne1 = src0->ne[1]; - const int ne2 = src0->ne[2]; - const int ne3 = src0->ne[3]; - - const int nb0 = src0->nb[0]; - const int nb1 = src0->nb[1]; - const int nb2 = src0->nb[2]; - const int nb3 = src0->nb[3]; - - //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3); - //printf("n_past = %d, ne2 = %d\n", n_past, ne2); - - assert(nb0 == sizeof(float)); - - // TODO: optimize 
- for (int i3 = 0; i3 < ne3; i3++) { - for (int i2 = (mode == 0 ? 0 : n_past); i2 < ne2; i2++) { - const int p = (mode == 0 ? n_past + i2 : i2); - for (int i1 = 0; i1 < ne1; i1++) { - for (int i0 = 0; i0 < n_dims; i0 += 2) { - const double theta = pow(10000.0, ((double)-i0)/n_dims); - - const double cos_theta = cos(p*theta); - const double sin_theta = sin(p*theta); - - const float * const src = (float *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - double x0 = src[0]; - double x1 = src[1]; - - dst_data[0] = x0*cos_theta - x1*sin_theta; - dst_data[1] = x0*sin_theta + x1*cos_theta; - } - } - } - } -} - -static void ggml_compute_forward_rope( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst) { - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_rope_f32(params, src0, src1, dst); - } break; - case GGML_TYPE_I8: - case GGML_TYPE_I16: - case GGML_TYPE_I32: - case GGML_TYPE_F16: - case GGML_TYPE_COUNT: - { - assert(false); - } break; - } -} - -// ggml_compute_forward_conv_1d_1s - -static void ggml_compute_forward_conv_1d_1s_f16_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst) { - GGML_ASSERT(src0->type == GGML_TYPE_F16); - GGML_ASSERT(src1->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); - - int64_t t0 = ggml_perf_time_us(); - UNUSED(t0); - - const int ne00 = src0->ne[0]; - const int ne01 = src0->ne[1]; - const int ne02 = src0->ne[2]; - //const int ne03 = src0->ne[3]; - - const int ne10 = src1->ne[0]; - const int ne11 = src1->ne[1]; - //const int ne12 = src1->ne[2]; - //const int ne13 = src1->ne[3]; - - //const int ne0 = dst->ne[0]; - //const int ne1 = dst->ne[1]; - //const int ne2 = dst->ne[2]; - //const int ne3 = dst->ne[3]; - //const int ne = ne0*ne1*ne2*ne3; - - const int nb00 = src0->nb[0]; - const int nb01 = src0->nb[1]; - const int nb02 = src0->nb[2]; - //const int nb03 = src0->nb[3]; - - const int nb10 = src1->nb[0]; - const int nb11 = src1->nb[1]; - //const int nb12 = src1->nb[2]; - //const int nb13 = src1->nb[3]; - - //const int nb0 = dst->nb[0]; - const int nb1 = dst->nb[1]; - //const int nb2 = dst->nb[2]; - //const int nb3 = dst->nb[3]; - - const int ith = params->ith; - const int nth = params->nth; - - const int nk = ne00; - const int nh = nk/2; - - const int ew0 = ggml_up32(ne01); - - GGML_ASSERT(ne00 % 2 == 1); // TODO: support even kernel sizes - GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); - GGML_ASSERT(nb10 == sizeof(float)); - - if (params->type == GGML_TASK_INIT) { - // TODO: fix this memset (wsize is overestimated) - memset(params->wdata, 0, params->wsize); - - // prepare kernel data (src0) - { - ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0; - - for (int i02 = 0; i02 < ne02; i02++) { - for (int i01 = 0; i01 < ne01; i01++) { - const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i02*nb02 + i01*nb01); - ggml_fp16_t * dst_data = wdata + i02*ew0*ne00; - for (int i00 = 0; i00 < ne00; i00++) { - dst_data[i00*ew0 + i01] = src[i00]; - } - } - } - } - - // prepare source data (src1) - { - ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + ne02*ew0*ne00; - - for (int i11 = 0; i11 < ne11; i11++) { - const float * const src = (float *)((char *) src1->data + i11*nb11); - ggml_fp16_t * dst_data = wdata; - for (int i10 
= 0; i10 < ne10; i10++) { - dst_data[(i10 + nh)*ew0 + i11] = GGML_FP32_TO_FP16(src[i10]); - } - } - } - - return; - } - - if (params->type == GGML_TASK_FINALIZE) { - return; - } - - // total rows in dst - const int nr = ne02; - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - for (int i1 = ir0; i1 < ir1; i1++) { - float * dst_data = (float *)((char *) dst->data + i1*nb1); - for (int i0 = 0; i0 < ne10; ++i0) { - dst_data[i0] = 0; - for (int k = -nh; k <= nh; k++) { - float v = 0.0f; - ggml_vec_dot_f16(ew0, &v, - (ggml_fp16_t *) params->wdata + i1*ew0*ne00 + (nh + k)*ew0, - (ggml_fp16_t *) params->wdata + ne02*ew0*ne00 + (i0 + nh + k)*ew0); - - dst_data[i0] += v; - } - } - } -} - -static void ggml_compute_forward_conv_1d_1s_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst) { - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT(src1->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); - - int64_t t0 = ggml_perf_time_us(); - UNUSED(t0); - - const int ne00 = src0->ne[0]; - const int ne01 = src0->ne[1]; - const int ne02 = src0->ne[2]; - //const int ne03 = src0->ne[3]; - - const int ne10 = src1->ne[0]; - const int ne11 = src1->ne[1]; - //const int ne12 = src1->ne[2]; - //const int ne13 = src1->ne[3]; - - //const int ne0 = dst->ne[0]; - //const int ne1 = dst->ne[1]; - //const int ne2 = dst->ne[2]; - //const int ne3 = dst->ne[3]; - //const int ne = ne0*ne1*ne2*ne3; - - const int nb00 = src0->nb[0]; - const int nb01 = src0->nb[1]; - const int nb02 = src0->nb[2]; - //const int nb03 = src0->nb[3]; - - const int nb10 = src1->nb[0]; - const int nb11 = src1->nb[1]; - //const int nb12 = src1->nb[2]; - //const int nb13 = src1->nb[3]; - - //const int nb0 = dst->nb[0]; - const int nb1 = dst->nb[1]; - //const int nb2 = dst->nb[2]; - //const int nb3 = dst->nb[3]; - - const int ith = params->ith; - const int nth = params->nth; - - const int nk = ne00; - const int nh = nk/2; - - const int ew0 = ggml_up32(ne01); - - GGML_ASSERT(ne00 % 2 == 1); // TODO: support even kernel sizes - GGML_ASSERT(nb00 == sizeof(float)); - GGML_ASSERT(nb10 == sizeof(float)); - - if (params->type == GGML_TASK_INIT) { - // TODO: fix this memset (wsize is overestimated) - memset(params->wdata, 0, params->wsize); - - // prepare kernel data (src0) - { - float * const wdata = (float *) params->wdata + 0; - - for (int i02 = 0; i02 < ne02; i02++) { - for (int i01 = 0; i01 < ne01; i01++) { - const float * const src = (float *)((char *) src0->data + i02*nb02 + i01*nb01); - float * dst_data = wdata + i02*ew0*ne00; - for (int i00 = 0; i00 < ne00; i00++) { - dst_data[i00*ew0 + i01] = src[i00]; - } - } - } - } - - // prepare source data (src1) - { - float * const wdata = (float *) params->wdata + ne02*ew0*ne00; - - for (int i11 = 0; i11 < ne11; i11++) { - const float * const src = (float *)((char *) src1->data + i11*nb11); - float * dst_data = wdata; - for (int i10 = 0; i10 < ne10; i10++) { - dst_data[(i10 + nh)*ew0 + i11] = src[i10]; - } - } - } - - return; - } - - if (params->type == GGML_TASK_FINALIZE) { - return; - } - - // total rows in dst - const int nr = ne02; - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - for (int i1 = ir0; i1 < ir1; i1++) { - float * dst_data = (float *)((char *) dst->data + i1*nb1); - 
for (int i0 = 0; i0 < ne10; ++i0) { - dst_data[i0] = 0; - for (int k = -nh; k <= nh; k++) { - float v = 0.0f; - ggml_vec_dot_f32(ew0, &v, - (float *) params->wdata + i1*ew0*ne00 + (nh + k)*ew0, - (float *) params->wdata + ne02*ew0*ne00 + (i0 + nh + k)*ew0); - - dst_data[i0] += v; - } - } - } -} - -static void ggml_compute_forward_conv_1d_1s( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst) { - switch (src0->type) { - case GGML_TYPE_F16: - { - ggml_compute_forward_conv_1d_1s_f16_f32(params, src0, src1, dst); - } break; - case GGML_TYPE_F32: - { - ggml_compute_forward_conv_1d_1s_f32(params, src0, src1, dst); - } break; - case GGML_TYPE_I8: - case GGML_TYPE_I16: - case GGML_TYPE_I32: - case GGML_TYPE_COUNT: - { - GGML_ASSERT(false); - } break; - } -} - -// ggml_compute_forward_conv_1d_2s - -static void ggml_compute_forward_conv_1d_2s_f16_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst) { - GGML_ASSERT(src0->type == GGML_TYPE_F16); - GGML_ASSERT(src1->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); - - int64_t t0 = ggml_perf_time_us(); - UNUSED(t0); - - const int ne00 = src0->ne[0]; - const int ne01 = src0->ne[1]; - const int ne02 = src0->ne[2]; - //const int ne03 = src0->ne[3]; - - const int ne10 = src1->ne[0]; - const int ne11 = src1->ne[1]; - //const int ne12 = src1->ne[2]; - //const int ne13 = src1->ne[3]; - - //const int ne0 = dst->ne[0]; - //const int ne1 = dst->ne[1]; - //const int ne2 = dst->ne[2]; - //const int ne3 = dst->ne[3]; - //const int ne = ne0*ne1*ne2*ne3; - - const int nb00 = src0->nb[0]; - const int nb01 = src0->nb[1]; - const int nb02 = src0->nb[2]; - //const int nb03 = src0->nb[3]; - - const int nb10 = src1->nb[0]; - const int nb11 = src1->nb[1]; - //const int nb12 = src1->nb[2]; - //const int nb13 = src1->nb[3]; - - //const int nb0 = dst->nb[0]; - const int nb1 = dst->nb[1]; - //const int nb2 = dst->nb[2]; - //const int nb3 = dst->nb[3]; - - const int ith = params->ith; - const int nth = params->nth; - - const int nk = ne00; - const int nh = nk/2; - - const int ew0 = ggml_up32(ne01); - - GGML_ASSERT(ne00 % 2 == 1); // TODO: support even kernel sizes - GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); - GGML_ASSERT(nb10 == sizeof(float)); - - if (params->type == GGML_TASK_INIT) { - // TODO: fix this memset (wsize is overestimated) - memset(params->wdata, 0, params->wsize); - - // prepare kernel data (src0) - { - ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0; - - for (int i02 = 0; i02 < ne02; i02++) { - for (int i01 = 0; i01 < ne01; i01++) { - const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i02*nb02 + i01*nb01); - ggml_fp16_t * dst_data = wdata + i02*ew0*ne00; - for (int i00 = 0; i00 < ne00; i00++) { - dst_data[i00*ew0 + i01] = src[i00]; - } - } - } - } - - // prepare source data (src1) - { - ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + ne02*ew0*ne00; - - for (int i11 = 0; i11 < ne11; i11++) { - const float * const src = (float *)((char *) src1->data + i11*nb11); - ggml_fp16_t * dst_data = wdata; - for (int i10 = 0; i10 < ne10; i10++) { - dst_data[(i10 + nh)*ew0 + i11] = GGML_FP32_TO_FP16(src[i10]); - } - } - } - - return; - } - - if (params->type == GGML_TASK_FINALIZE) { - return; - } - - // total rows in dst - const int nr = ne02; - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this 
thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - for (int i1 = ir0; i1 < ir1; i1++) { - float * dst_data = (float *)((char *) dst->data + i1*nb1); - for (int i0 = 0; i0 < ne10; i0 += 2) { - dst_data[i0/2] = 0; - for (int k = -nh; k <= nh; k++) { - float v = 0.0f; - ggml_vec_dot_f16(ew0, &v, - (ggml_fp16_t *) params->wdata + i1*ew0*ne00 + (nh + k)*ew0, - (ggml_fp16_t *) params->wdata + ne02*ew0*ne00 + (i0 + nh + k)*ew0); - - dst_data[i0/2] += v; - } - } - } -} - -static void ggml_compute_forward_conv_1d_2s_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst) { - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT(src1->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); - - int64_t t0 = ggml_perf_time_us(); - UNUSED(t0); - - const int ne00 = src0->ne[0]; - const int ne01 = src0->ne[1]; - const int ne02 = src0->ne[2]; - //const int ne03 = src0->ne[3]; - - const int ne10 = src1->ne[0]; - const int ne11 = src1->ne[1]; - //const int ne12 = src1->ne[2]; - //const int ne13 = src1->ne[3]; - - //const int ne0 = dst->ne[0]; - //const int ne1 = dst->ne[1]; - //const int ne2 = dst->ne[2]; - //const int ne3 = dst->ne[3]; - //const int ne = ne0*ne1*ne2*ne3; - - const int nb00 = src0->nb[0]; - const int nb01 = src0->nb[1]; - const int nb02 = src0->nb[2]; - //const int nb03 = src0->nb[3]; - - const int nb10 = src1->nb[0]; - const int nb11 = src1->nb[1]; - //const int nb12 = src1->nb[2]; - //const int nb13 = src1->nb[3]; - - //const int nb0 = dst->nb[0]; - const int nb1 = dst->nb[1]; - //const int nb2 = dst->nb[2]; - //const int nb3 = dst->nb[3]; - - const int ith = params->ith; - const int nth = params->nth; - - const int nk = ne00; - const int nh = nk/2; - - const int ew0 = ggml_up32(ne01); - - GGML_ASSERT(ne00 % 2 == 1); // TODO: support even kernel sizes - GGML_ASSERT(nb00 == sizeof(float)); - GGML_ASSERT(nb10 == sizeof(float)); - - if (params->type == GGML_TASK_INIT) { - // TODO: fix this memset (wsize is overestimated) - memset(params->wdata, 0, params->wsize); - - // prepare kernel data (src0) - { - float * const wdata = (float *) params->wdata + 0; - - for (int i02 = 0; i02 < ne02; i02++) { - for (int i01 = 0; i01 < ne01; i01++) { - const float * const src = (float *)((char *) src0->data + i02*nb02 + i01*nb01); - float * dst_data = wdata + i02*ew0*ne00; - for (int i00 = 0; i00 < ne00; i00++) { - dst_data[i00*ew0 + i01] = src[i00]; - } - } - } - } - - // prepare source data (src1) - { - float * const wdata = (float *) params->wdata + ne02*ew0*ne00; - - for (int i11 = 0; i11 < ne11; i11++) { - const float * const src = (float *)((char *) src1->data + i11*nb11); - float * dst_data = wdata; - for (int i10 = 0; i10 < ne10; i10++) { - dst_data[(i10 + nh)*ew0 + i11] = src[i10]; - } - } - } - - return; - } - - if (params->type == GGML_TASK_FINALIZE) { - return; - } - - // total rows in dst - const int nr = ne02; - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - for (int i1 = ir0; i1 < ir1; i1++) { - float * dst_data = (float *)((char *) dst->data + i1*nb1); - for (int i0 = 0; i0 < ne10; i0 += 2) { - dst_data[i0/2] = 0; - for (int k = -nh; k <= nh; k++) { - float v = 0.0f; - ggml_vec_dot_f32(ew0, &v, - (float *) params->wdata + i1*ew0*ne00 + (nh + k)*ew0, - (float *) params->wdata + ne02*ew0*ne00 + (i0 + nh + k)*ew0); - - dst_data[i0/2] += v; - } - } - } -} - 
-static void ggml_compute_forward_conv_1d_2s( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst) { - switch (src0->type) { - case GGML_TYPE_F16: - { - ggml_compute_forward_conv_1d_2s_f16_f32(params, src0, src1, dst); - } break; - case GGML_TYPE_F32: - { - ggml_compute_forward_conv_1d_2s_f32(params, src0, src1, dst); - } break; - case GGML_TYPE_I8: - case GGML_TYPE_I16: - case GGML_TYPE_I32: - case GGML_TYPE_COUNT: - { - GGML_ASSERT(false); - } break; - } -} - -// ggml_compute_forward_flash_attn - -static void ggml_compute_forward_flash_attn_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * q, - const struct ggml_tensor * k, - const struct ggml_tensor * v, - const bool masked, - struct ggml_tensor * dst) { - int64_t t0 = ggml_perf_time_us(); - UNUSED(t0); - - const int neq0 = q->ne[0]; - const int neq1 = q->ne[1]; - const int neq2 = q->ne[2]; - const int neq3 = q->ne[3]; - - const int nek0 = k->ne[0]; - const int nek1 = k->ne[1]; - //const int nek2 = k->ne[2]; - //const int nek3 = k->ne[3]; - - //const int nev0 = v->ne[0]; - const int nev1 = v->ne[1]; - //const int nev2 = v->ne[2]; - //const int nev3 = v->ne[3]; - - const int ne0 = dst->ne[0]; - const int ne1 = dst->ne[1]; - //const int ne2 = dst->ne[2]; - //const int ne3 = dst->ne[3]; - - const int nbk0 = k->nb[0]; - const int nbk1 = k->nb[1]; - const int nbk2 = k->nb[2]; - const int nbk3 = k->nb[3]; - - const int nbq0 = q->nb[0]; - const int nbq1 = q->nb[1]; - const int nbq2 = q->nb[2]; - const int nbq3 = q->nb[3]; - - const int nbv0 = v->nb[0]; - const int nbv1 = v->nb[1]; - const int nbv2 = v->nb[2]; - const int nbv3 = v->nb[3]; - - const int nb0 = dst->nb[0]; - const int nb1 = dst->nb[1]; - const int nb2 = dst->nb[2]; - const int nb3 = dst->nb[3]; - - const int ith = params->ith; - const int nth = params->nth; - - const int D = neq0; - const int N = neq1; - const int P = nek1 - N; - const int M = P + N; - - const int Mup = ggml_up(M, GGML_SOFT_MAX_UNROLL); - - GGML_ASSERT(ne0 == D); - GGML_ASSERT(ne1 == N); - GGML_ASSERT(P >= 0); - - GGML_ASSERT(nbq0 == sizeof(float)); - GGML_ASSERT(nbk0 == sizeof(float)); - GGML_ASSERT(nbv0 == sizeof(float)); - - GGML_ASSERT(neq0 == D); - GGML_ASSERT(nek0 == D); - GGML_ASSERT(nev1 == D); - - GGML_ASSERT(neq1 == N); - GGML_ASSERT(nek1 == N + P); - GGML_ASSERT(nev1 == D); - - // dst cannot be transposed or permuted - GGML_ASSERT(nb0 == sizeof(float)); - GGML_ASSERT(nb0 <= nb1); - GGML_ASSERT(nb1 <= nb2); - GGML_ASSERT(nb2 <= nb3); - - if (params->type == GGML_TASK_INIT) { - return; - } - - if (params->type == GGML_TASK_FINALIZE) { - return; - } - - // parallelize by q rows using ggml_vec_dot_f32 - - // total rows in q - const int nr = neq1*neq2*neq3; - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - const float scale = 1.0/sqrt((double) D); - - //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale); - - for (int ir = ir0; ir < ir1; ++ir) { - // q indices - const int iq3 = ir/(neq2*neq1); - const int iq2 = (ir - iq3*neq2*neq1)/neq1; - const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1); - - float * S = (float *) params->wdata + ith*(Mup + CACHE_LINE_SIZE_F32); - - for (int i = M; i < Mup; ++i) { - S[i] = -INFINITY; - } - - for (int ic = 0; ic < nek1; ++ic) { - // k indices - const int ik3 = iq3; - const int ik2 = iq2; - const int ik1 = ic; - - // S indices 
- const int i1 = ik1; - - ggml_vec_dot_f32(neq0, - S + i1, - (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), - (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3))); - } - - // scale - ggml_vec_scale_f32(nek1, S, scale); - - if (masked) { - for (int i = P; i < M; i++) { - if (i > P + iq1) { - S[i] = -INFINITY; - } - } - } - - // softmax - { - float max = -INFINITY; - ggml_vec_max_f32(M, &max, S); - - float sum = 0.0f; - { -#ifdef GGML_SOFT_MAX_ACCELERATE - max = -max; - vDSP_vsadd(S, 1, &max, S, 1, Mup); - vvexpf(S, S, &Mup); - ggml_vec_sum_f32(Mup, &sum, S); -#else - uint16_t scvt[GGML_SOFT_MAX_UNROLL]; - ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 }; - - for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) { - float * SS = S + i; - - for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) { - if (SS[j] == -INFINITY) { - SS[j] = 0.0f; - } else { - ggml_fp16_t s = GGML_FP32_TO_FP16(SS[j] - max); - memcpy(&scvt[j], &s, sizeof(uint16_t)); - const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]); - sump[j] += val; - SS[j] = val; - } - } - } - - for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) { - sum += sump[i]; - } -#endif - } - - assert(sum > 0.0f); - - sum = 1.0/sum; - ggml_vec_scale_f32(M, S, sum); - -#ifndef NDEBUG - for (int i = 0; i < M; ++i) { - assert(!isnan(S[i])); - assert(!isinf(S[i])); - } -#endif - } - - for (int ic = 0; ic < nev1; ++ic) { - // dst indices - const int i1 = iq1; - const int i2 = iq2; - const int i3 = iq3; - - ggml_vec_dot_f32(nek1, - (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), - (float *) ((char *) v->data + ( ic*nbv1 + i2*nbv2 + i3*nbv3)), - S); - } - } -} - -static void ggml_compute_forward_flash_attn_f16( - const struct ggml_compute_params * params, - const struct ggml_tensor * q, - const struct ggml_tensor * k, - const struct ggml_tensor * v, - const bool masked, - struct ggml_tensor * dst) { - int64_t t0 = ggml_perf_time_us(); - UNUSED(t0); - - const int neq0 = q->ne[0]; - const int neq1 = q->ne[1]; - const int neq2 = q->ne[2]; - const int neq3 = q->ne[3]; - - const int nek0 = k->ne[0]; - const int nek1 = k->ne[1]; - //const int nek2 = k->ne[2]; - //const int nek3 = k->ne[3]; - - //const int nev0 = v->ne[0]; - const int nev1 = v->ne[1]; - //const int nev2 = v->ne[2]; - //const int nev3 = v->ne[3]; - - const int ne0 = dst->ne[0]; - const int ne1 = dst->ne[1]; - //const int ne2 = dst->ne[2]; - //const int ne3 = dst->ne[3]; - - const int nbk0 = k->nb[0]; - const int nbk1 = k->nb[1]; - const int nbk2 = k->nb[2]; - const int nbk3 = k->nb[3]; - - const int nbq0 = q->nb[0]; - const int nbq1 = q->nb[1]; - const int nbq2 = q->nb[2]; - const int nbq3 = q->nb[3]; - - const int nbv0 = v->nb[0]; - const int nbv1 = v->nb[1]; - const int nbv2 = v->nb[2]; - const int nbv3 = v->nb[3]; - - const int nb0 = dst->nb[0]; - const int nb1 = dst->nb[1]; - const int nb2 = dst->nb[2]; - const int nb3 = dst->nb[3]; - - const int ith = params->ith; - const int nth = params->nth; - - const int D = neq0; - const int N = neq1; - const int P = nek1 - N; - const int M = P + N; - - const int Mup = ggml_up(M, GGML_SOFT_MAX_UNROLL); - - GGML_ASSERT(ne0 == D); - GGML_ASSERT(ne1 == N); - GGML_ASSERT(P >= 0); - - GGML_ASSERT(nbq0 == sizeof(ggml_fp16_t)); - GGML_ASSERT(nbk0 == sizeof(ggml_fp16_t)); - GGML_ASSERT(nbv0 == sizeof(ggml_fp16_t)); - - GGML_ASSERT(neq0 == D); - GGML_ASSERT(nek0 == D); - GGML_ASSERT(nev1 == D); - - GGML_ASSERT(neq1 == N); - GGML_ASSERT(nek1 == N + P); - GGML_ASSERT(nev1 == D); - - // dst cannot be transposed or permuted - 
GGML_ASSERT(nb0 == sizeof(float)); - GGML_ASSERT(nb0 <= nb1); - GGML_ASSERT(nb1 <= nb2); - GGML_ASSERT(nb2 <= nb3); - - if (params->type == GGML_TASK_INIT) { - return; - } - - if (params->type == GGML_TASK_FINALIZE) { - return; - } - - // parallelize by q rows using ggml_vec_dot_f32 - - // total rows in q - const int nr = neq1*neq2*neq3; - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - const float scale = 1.0/sqrt((double) D); - - //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale); - - for (int ir = ir0; ir < ir1; ++ir) { - // q indices - const int iq3 = ir/(neq2*neq1); - const int iq2 = (ir - iq3*neq2*neq1)/neq1; - const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1); - - float * S = (float *) params->wdata + ith*(2*Mup + CACHE_LINE_SIZE_F32); - - for (int i = M; i < Mup; ++i) { - S[i] = -INFINITY; - } - - if (GGML_VEC_DOT_UNROLL > 2 || nek1 % GGML_VEC_DOT_UNROLL != 0) { - for (int ic = 0; ic < nek1; ++ic) { - // k indices - const int ik3 = iq3; - const int ik2 = iq2; - const int ik1 = ic; - - // S indices - const int i1 = ik1; - - ggml_vec_dot_f16(neq0, - S + i1, - (ggml_fp16_t *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), - (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3))); - } - } else { - for (int ic = 0; ic < nek1; ic += GGML_VEC_DOT_UNROLL) { - // k indices - const int ik3 = iq3; - const int ik2 = iq2; - const int ik1 = ic; - - // S indices - const int i1 = ik1; - - ggml_vec_dot_f16_unroll(neq0, nbk1, - S + i1, - ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), - (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3))); - } - } - - // scale - ggml_vec_scale_f32(nek1, S, scale); - - if (masked) { - for (int i = P; i < M; i++) { - if (i > P + iq1) { - S[i] = -INFINITY; - } - } - } - - // softmax - { - float max = -INFINITY; - ggml_vec_max_f32(M, &max, S); - - float sum = 0.0f; - { -#ifdef GGML_SOFT_MAX_ACCELERATE - max = -max; - vDSP_vsadd(S, 1, &max, S, 1, Mup); - vvexpf(S, S, &Mup); - ggml_vec_sum_f32(Mup, &sum, S); -#else - uint16_t scvt[GGML_SOFT_MAX_UNROLL]; - ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 }; - - for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) { - float * SS = S + i; - - for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) { - if (SS[j] == -INFINITY) { - SS[j] = 0.0f; - } else { - ggml_fp16_t s = GGML_FP32_TO_FP16(SS[j] - max); - memcpy(&scvt[j], &s, sizeof(uint16_t)); - const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]); - sump[j] += val; - SS[j] = val; - } - } - } - - for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) { - sum += sump[i]; - } -#endif - } - - assert(sum > 0.0f); - - sum = 1.0/sum; - ggml_vec_scale_f32(M, S, sum); - -#ifndef NDEBUG - for (int i = 0; i < M; ++i) { - assert(!isnan(S[i])); - assert(!isinf(S[i])); - } -#endif - } - - ggml_fp16_t * S16 = (ggml_fp16_t *) ((float *) params->wdata + ith*(2*Mup + CACHE_LINE_SIZE_F32) + Mup); - - for (int i = 0; i < M; i++) { - S16[i] = GGML_FP32_TO_FP16(S[i]); - } - - if (GGML_VEC_DOT_UNROLL == 1 || (nev1 % GGML_VEC_DOT_UNROLL != 0)) { - for (int ic = 0; ic < nev1; ++ic) { - // dst indices - const int i1 = iq1; - const int i2 = iq2; - const int i3 = iq3; - - ggml_vec_dot_f16(nek1, - (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), - (ggml_fp16_t *) ((char *) v->data + ( ic*nbv1 + i2*nbv2 + i3*nbv3)), - S16); - } - } else { - for (int ic = 0; ic < nev1; ic += GGML_VEC_DOT_UNROLL) { - // dst indices - 
const int i1 = iq1; - const int i2 = iq2; - const int i3 = iq3; - - ggml_vec_dot_f16_unroll(nek1, nbv1, - (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), - ((char *) v->data + ( ic*nbv1 + i2*nbv2 + i3*nbv3)), - S16); - } - } - } -} - -static void ggml_compute_forward_flash_attn( - const struct ggml_compute_params * params, - const struct ggml_tensor * q, - const struct ggml_tensor * k, - const struct ggml_tensor * v, - const bool masked, - struct ggml_tensor * dst) { - switch (q->type) { - case GGML_TYPE_F16: - { - ggml_compute_forward_flash_attn_f16(params, q, k, v, masked, dst); - } break; - case GGML_TYPE_F32: - { - ggml_compute_forward_flash_attn_f32(params, q, k, v, masked, dst); - } break; - case GGML_TYPE_I8: - case GGML_TYPE_I16: - case GGML_TYPE_I32: - case GGML_TYPE_COUNT: - { - assert(false); - } break; - } -} - -// ggml_compute_forward_flash_ff - -static void ggml_compute_forward_flash_ff_f16( - const struct ggml_compute_params * params, - const struct ggml_tensor * a, // F16 - const struct ggml_tensor * b0, // F16 fc_w - const struct ggml_tensor * b1, // F32 fc_b - const struct ggml_tensor * c0, // F16 proj_w - const struct ggml_tensor * c1, // F32 proj_b - struct ggml_tensor * dst) { - int64_t t0 = ggml_perf_time_us(); - UNUSED(t0); - - const int nea0 = a->ne[0]; - const int nea1 = a->ne[1]; - const int nea2 = a->ne[2]; - const int nea3 = a->ne[3]; - - const int neb00 = b0->ne[0]; - const int neb01 = b0->ne[1]; - //const int neb02 = b0->ne[2]; - //const int neb03 = b0->ne[3]; - - const int neb10 = b1->ne[0]; - const int neb11 = b1->ne[1]; - //const int neb12 = b1->ne[2]; - //const int neb13 = b1->ne[3]; - - const int nec00 = c0->ne[0]; - const int nec01 = c0->ne[1]; - //const int nec02 = c0->ne[2]; - //const int nec03 = c0->ne[3]; - - const int nec10 = c1->ne[0]; - const int nec11 = c1->ne[1]; - //const int nec12 = c1->ne[2]; - //const int nec13 = c1->ne[3]; - - const int ne0 = dst->ne[0]; - const int ne1 = dst->ne[1]; - const int ne2 = dst->ne[2]; - //const int ne3 = dst->ne[3]; - - const int nba0 = a->nb[0]; - const int nba1 = a->nb[1]; - const int nba2 = a->nb[2]; - const int nba3 = a->nb[3]; - - const int nbb00 = b0->nb[0]; - const int nbb01 = b0->nb[1]; - const int nbb02 = b0->nb[2]; - const int nbb03 = b0->nb[3]; - - const int nbb10 = b1->nb[0]; - //const int nbb11 = b1->nb[1]; - //const int nbb12 = b1->nb[2]; - //const int nbb13 = b1->nb[3]; - - const int nbc00 = c0->nb[0]; - const int nbc01 = c0->nb[1]; - const int nbc02 = c0->nb[2]; - const int nbc03 = c0->nb[3]; - - const int nbc10 = c1->nb[0]; - //const int nbc11 = c1->nb[1]; - //const int nbc12 = c1->nb[2]; - //const int nbc13 = c1->nb[3]; - - const int nb0 = dst->nb[0]; - const int nb1 = dst->nb[1]; - const int nb2 = dst->nb[2]; - const int nb3 = dst->nb[3]; - - const int ith = params->ith; - const int nth = params->nth; - - const int D = nea0; - //const int N = nea1; - const int M = neb01; - - GGML_ASSERT(ne0 == nea0); - GGML_ASSERT(ne1 == nea1); - GGML_ASSERT(ne2 == nea2); - - GGML_ASSERT(nba0 == sizeof(ggml_fp16_t)); - GGML_ASSERT(nbb00 == sizeof(ggml_fp16_t)); - GGML_ASSERT(nbb10 == sizeof(float)); - GGML_ASSERT(nbc00 == sizeof(ggml_fp16_t)); - GGML_ASSERT(nbc10 == sizeof(float)); - - GGML_ASSERT(neb00 == D); - GGML_ASSERT(neb01 == M); - GGML_ASSERT(neb10 == M); - GGML_ASSERT(neb11 == 1); - - GGML_ASSERT(nec00 == M); - GGML_ASSERT(nec01 == D); - GGML_ASSERT(nec10 == D); - GGML_ASSERT(nec11 == 1); - - // dst cannot be transposed or permuted - GGML_ASSERT(nb0 == sizeof(float)); - 
GGML_ASSERT(nb0 <= nb1); - GGML_ASSERT(nb1 <= nb2); - GGML_ASSERT(nb2 <= nb3); - - if (params->type == GGML_TASK_INIT) { - return; - } - - if (params->type == GGML_TASK_FINALIZE) { - return; - } - - // parallelize by a rows using ggml_vec_dot_f32 - - // total rows in a - const int nr = nea1*nea2*nea3; - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - for (int ir = ir0; ir < ir1; ++ir) { - // a indices - const int ia3 = ir/(nea2*nea1); - const int ia2 = (ir - ia3*nea2*nea1)/nea1; - const int ia1 = (ir - ia3*nea2*nea1 - ia2*nea1); - - float * S = (float *) params->wdata + ith*(2*M + CACHE_LINE_SIZE_F32); - - for (int ic = 0; ic < neb01; ++ic) { - // b0 indices - const int ib03 = ia3; - const int ib02 = ia2; - const int ib01 = ic; - - // S indices - const int i1 = ib01; - - ggml_vec_dot_f16(nea0, - S + i1, - (ggml_fp16_t *) ((char *) b0->data + (ib01*nbb01 + ib02*nbb02 + ib03*nbb03)), - (ggml_fp16_t *) ((char *) a->data + ( ia1*nba1 + ia2*nba2 + ia3*nba3))); - } - - ggml_vec_add_f32(neb01, S, S, (float *) b1->data); - //ggml_vec_gelu_f32(neb01, S, S); - - ggml_fp16_t * S16 = (ggml_fp16_t *) ((float *) params->wdata + ith*(2*M + CACHE_LINE_SIZE_F32) + M); - - for (int i = 0; i < M; i++) { - S16[i] = GGML_FP32_TO_FP16(S[i]); - } - - ggml_vec_gelu_f16(neb01, S16, S16); - - { - // dst indices - const int i1 = ia1; - const int i2 = ia2; - const int i3 = ia3; - - for (int ic = 0; ic < nec01; ++ic) { - - ggml_vec_dot_f16(neb01, - (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), - (ggml_fp16_t *) ((char *) c0->data + ( ic*nbc01 + i2*nbc02 + i3*nbc03)), - S16); - } - - ggml_vec_add_f32(nec01, - (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3)), - (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3)), - (float *) c1->data); - } - } -} - -static void ggml_compute_forward_flash_ff( - const struct ggml_compute_params * params, - const struct ggml_tensor * a, - const struct ggml_tensor * b0, - const struct ggml_tensor * b1, - const struct ggml_tensor * c0, - const struct ggml_tensor * c1, - struct ggml_tensor * dst) { - switch (b0->type) { - case GGML_TYPE_F16: - { - ggml_compute_forward_flash_ff_f16(params, a, b0, b1, c0, c1, dst); - } break; - case GGML_TYPE_F32: - { - GGML_ASSERT(false); // TODO - } break; - case GGML_TYPE_I8: - case GGML_TYPE_I16: - case GGML_TYPE_I32: - case GGML_TYPE_COUNT: - { - assert(false); - } break; - } -} - -///////////////////////////////// - -static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) { - assert(params); - - switch (tensor->op) { - case GGML_OP_DUP: - { - ggml_compute_forward_dup(params, tensor->src0, tensor); - } break; - case GGML_OP_ADD: - { - ggml_compute_forward_add(params, tensor->src0, tensor->src1, tensor); - } break; - case GGML_OP_SUB: - { - ggml_compute_forward_sub(params, tensor->src0, tensor->src1, tensor); - } break; - case GGML_OP_MUL: - { - ggml_compute_forward_mul(params, tensor->src0, tensor->src1, tensor); - } break; - case GGML_OP_DIV: - { - ggml_compute_forward_div(params, tensor->src0, tensor->src1, tensor); - } break; - case GGML_OP_SQR: - { - ggml_compute_forward_sqr(params, tensor->src0, tensor); - } break; - case GGML_OP_SQRT: - { - ggml_compute_forward_sqrt(params, tensor->src0, tensor); - } break; - case GGML_OP_SUM: - { - ggml_compute_forward_sum(params, tensor->src0, tensor); - } break; - case GGML_OP_MEAN: - { - ggml_compute_forward_mean(params, 
tensor->src0, tensor); - } break; - case GGML_OP_REPEAT: - { - ggml_compute_forward_repeat(params, tensor->src0, tensor); - } break; - case GGML_OP_ABS: - { - ggml_compute_forward_abs(params, tensor->src0, tensor); - } break; - case GGML_OP_SGN: - { - ggml_compute_forward_sgn(params, tensor->src0, tensor); - } break; - case GGML_OP_NEG: - { - ggml_compute_forward_neg(params, tensor->src0, tensor); - } break; - case GGML_OP_STEP: - { - ggml_compute_forward_step(params, tensor->src0, tensor); - } break; - case GGML_OP_RELU: - { - ggml_compute_forward_relu(params, tensor->src0, tensor); - } break; - case GGML_OP_GELU: - { - ggml_compute_forward_gelu(params, tensor->src0, tensor); - } break; - case GGML_OP_NORM: - { - ggml_compute_forward_norm(params, tensor->src0, tensor); - } break; - case GGML_OP_MUL_MAT: - { - ggml_compute_forward_mul_mat(params, tensor->src0, tensor->src1, tensor); - } break; - case GGML_OP_SCALE: - { - ggml_compute_forward_scale(params, tensor->src0, tensor->src1, tensor); - } break; - case GGML_OP_CPY: - { - ggml_compute_forward_cpy(params, tensor->src0, tensor); - } break; - case GGML_OP_RESHAPE: - { - ggml_compute_forward_reshape(params, tensor->src0, tensor); - } break; - case GGML_OP_VIEW: - { - ggml_compute_forward_view(params, tensor->src0); - } break; - case GGML_OP_PERMUTE: - { - ggml_compute_forward_permute(params, tensor->src0); - } break; - case GGML_OP_TRANSPOSE: - { - ggml_compute_forward_transpose(params, tensor->src0); - } break; - case GGML_OP_GET_ROWS: - { - ggml_compute_forward_get_rows(params, tensor->src0, tensor->src1, tensor); - } break; - case GGML_OP_DIAG_MASK_INF: - { - ggml_compute_forward_diag_mask_inf(params, tensor->src0, tensor->src1, tensor); - } break; - case GGML_OP_SOFT_MAX: - { - ggml_compute_forward_soft_max(params, tensor->src0, tensor); - } break; - case GGML_OP_ROPE: - { - ggml_compute_forward_rope(params, tensor->src0, tensor->src1, tensor); - } break; - case GGML_OP_CONV_1D_1S: - { - ggml_compute_forward_conv_1d_1s(params, tensor->src0, tensor->src1, tensor); - } break; - case GGML_OP_CONV_1D_2S: - { - ggml_compute_forward_conv_1d_2s(params, tensor->src0, tensor->src1, tensor); - } break; - case GGML_OP_FLASH_ATTN: - { - int32_t t = ggml_get_i32_1d(tensor->opt[1], 0); - GGML_ASSERT(t == 0 || t == 1); - bool masked = t != 0; - ggml_compute_forward_flash_attn(params, tensor->src0, tensor->src1, tensor->opt[0], masked, tensor); - } break; - case GGML_OP_FLASH_FF: - { - ggml_compute_forward_flash_ff(params, tensor->src0, tensor->src1, tensor->opt[0], tensor->opt[1], tensor->opt[2], tensor); - } break; - case GGML_OP_NONE: - { - // nop - } break; - case GGML_OP_COUNT: - { - GGML_ASSERT(false); - } break; - } -} - -//////////////////////////////////////////////////////////////////////////////// - -static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tensor, bool inplace) { - struct ggml_tensor * src0 = tensor->src0; - struct ggml_tensor * src1 = tensor->src1; - - switch (tensor->op) { - case GGML_OP_DUP: - { - if (src0->grad) { - src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace); - } - } break; - case GGML_OP_ADD: - { - if (src0->grad) { - src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace); - } - if (src1->grad) { - src1->grad = ggml_add_impl(ctx, src1->grad, tensor->grad, inplace); - } - } break; - case GGML_OP_SUB: - { - if (src0->grad) { - src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace); - } - if (src1->grad) { - src1->grad = ggml_sub_impl(ctx, 
src1->grad, tensor->grad, inplace); - } - } break; - case GGML_OP_MUL: - { - if (src0->grad) { - src0->grad = - ggml_add_impl(ctx, - src0->grad, - ggml_mul(ctx, src1, tensor->grad), - inplace); - } - if (src1->grad) { - src1->grad = - ggml_add_impl(ctx, - src1->grad, - ggml_mul(ctx, src0, tensor->grad), - inplace); - } - } break; - case GGML_OP_DIV: - { - if (src0->grad) { - src0->grad = - ggml_add_impl(ctx, - src0->grad, - ggml_div(ctx, tensor->grad, src1), - inplace); - } - if (src1->grad) { - src1->grad = - ggml_sub_impl(ctx, - src1->grad, - ggml_mul(ctx, - tensor->grad, - ggml_div(ctx, tensor, src1)), - inplace); - } - } break; - case GGML_OP_SQR: - { - if (src0->grad) { - src0->grad = - ggml_add_impl(ctx, - src0->grad, - ggml_mul(ctx, - ggml_mul(ctx, src0, tensor->grad), - ggml_repeat(ctx, ggml_new_f32(ctx, 2.0f), src0)), - inplace); - } - } break; - case GGML_OP_SQRT: - { - if (src0->grad) { - src0->grad = - ggml_add_impl(ctx, - src0->grad, - ggml_div(ctx, - ggml_repeat(ctx, ggml_new_f32(ctx, 0.5f), tensor), - tensor), - inplace); - } - } break; - case GGML_OP_SUM: - { - if (src0->grad) { - src0->grad = - ggml_add_impl(ctx, - src0->grad, - ggml_repeat(ctx, tensor->grad, src0->grad), - inplace); - } - } break; - case GGML_OP_MEAN: - { - assert(false); // TODO: implement - } break; - case GGML_OP_REPEAT: - { - if (src0->grad) { - src0->grad = - ggml_add_impl(ctx, - src0->grad, - ggml_sum(ctx, tensor->grad), - inplace); - } - } break; - case GGML_OP_ABS: - { - if (src0->grad) { - src0->grad = - ggml_add_impl(ctx, - src0->grad, - ggml_mul(ctx, - ggml_sgn(ctx, src0), - tensor->grad), - inplace); - } - } break; - case GGML_OP_SGN: - { - if (src0->grad) { - // noop - } - } break; - case GGML_OP_NEG: - { - if (src0->grad) { - src0->grad = ggml_sub_impl(ctx, src0->grad, tensor->grad, inplace); - } - } break; - case GGML_OP_STEP: - { - if (src0->grad) { - // noop - } - } break; - case GGML_OP_RELU: - { - if (src0->grad) { - src0->grad = ggml_sub_impl(ctx, - src0->grad, - ggml_mul(ctx, - ggml_step(ctx, src0), - tensor->grad), - inplace); - } - } break; - case GGML_OP_GELU: - { - assert(false); // TODO: not implemented - } break; - case GGML_OP_NORM: - { - assert(false); // TODO: not implemented - } break; - case GGML_OP_MUL_MAT: - { - if (src0->grad) { - // TODO: this requires outer product - ggml_out_prod(ctx, src1, tensor->grad); - assert(false); - } - if (src1->grad) { - src1->grad = - ggml_add_impl(ctx, - src1->grad, - // TODO: fix transpose, the node will break the graph connections - ggml_mul_mat(ctx, ggml_transpose(ctx, src0), tensor->grad), - inplace); - } - } break; - case GGML_OP_SCALE: - { - GGML_ASSERT(false); // TODO: not implemented - } break; - case GGML_OP_CPY: - { - GGML_ASSERT(false); // TODO: not implemented - } break; - case GGML_OP_RESHAPE: - { - GGML_ASSERT(false); // TODO: not implemented - } break; - case GGML_OP_VIEW: - { - GGML_ASSERT(false); // not supported - } break; - case GGML_OP_PERMUTE: - { - GGML_ASSERT(false); // TODO: not implemented - } break; - case GGML_OP_TRANSPOSE: - { - GGML_ASSERT(false); // TODO: not implemented - } break; - case GGML_OP_GET_ROWS: - { - GGML_ASSERT(false); // TODO: not implemented - } break; - case GGML_OP_DIAG_MASK_INF: - { - GGML_ASSERT(false); // TODO: not implemented - } break; - case GGML_OP_SOFT_MAX: - { - GGML_ASSERT(false); // TODO: not implemented - } break; - case GGML_OP_ROPE: - { - GGML_ASSERT(false); // TODO: not implemented - } break; - case GGML_OP_CONV_1D_1S: - { - GGML_ASSERT(false); // TODO: not implemented - } 
break; - case GGML_OP_CONV_1D_2S: - { - GGML_ASSERT(false); // TODO: not implemented - } break; - case GGML_OP_FLASH_ATTN: - { - GGML_ASSERT(false); // not supported - } break; - case GGML_OP_FLASH_FF: - { - GGML_ASSERT(false); // not supported - } break; - case GGML_OP_NONE: - { - // nop - } break; - case GGML_OP_COUNT: - { - GGML_ASSERT(false); - } break; - } -} - -static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * node) { - if (node->grad == NULL) { - // this usually happens when we generate intermediate nodes from constants in the backward pass - // it can also happen during forward pass, if the user performs computations with constants - if (node->op != GGML_OP_NONE) { - //GGML_PRINT_DEBUG("%s: warning: node %p has no grad, but op %d\n", __func__, (void *) node, node->op); - } - } - - // check if already visited - for (int i = 0; i < cgraph->n_nodes; i++) { - if (cgraph->nodes[i] == node) { - return; - } - } - - for (int i = 0; i < cgraph->n_leafs; i++) { - if (cgraph->leafs[i] == node) { - return; - } - } - - if (node->src0) { - ggml_visit_parents(cgraph, node->src0); - } - - if (node->src1) { - ggml_visit_parents(cgraph, node->src1); - } - - for (int i = 0; i < GGML_MAX_OPT; ++i) { - if (node->opt[i]) { - ggml_visit_parents(cgraph, node->opt[i]); - } - } - - if (node->op == GGML_OP_NONE && node->grad == NULL) { - // reached a leaf node, not part of the gradient graph (e.g. a constant) - assert(cgraph->n_leafs < GGML_MAX_NODES); - - cgraph->leafs[cgraph->n_leafs] = node; - cgraph->n_leafs++; - } else { - assert(cgraph->n_nodes < GGML_MAX_NODES); - - cgraph->nodes[cgraph->n_nodes] = node; - cgraph->grads[cgraph->n_nodes] = node->grad; - cgraph->n_nodes++; - } -} - -static void ggml_build_forward_impl(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor, bool expand) { - if (!expand) { - cgraph->n_nodes = 0; - cgraph->n_leafs = 0; - } - - const int n0 = cgraph->n_nodes; - UNUSED(n0); - - ggml_visit_parents(cgraph, tensor); - - const int n_new = cgraph->n_nodes - n0; - GGML_PRINT_DEBUG("%s: visited %d new nodes\n", __func__, n_new); - - if (n_new > 0) { - // the last added node should always be starting point - assert(cgraph->nodes[cgraph->n_nodes - 1] == tensor); - } -} - -void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor) { - ggml_build_forward_impl(cgraph, tensor, true); -} - -struct ggml_cgraph ggml_build_forward(struct ggml_tensor * tensor) { - struct ggml_cgraph result = { - /*.n_nodes =*/ 0, - /*.n_leafs =*/ 0, - /*.n_threads =*/ 0, - /*.work_size =*/ 0, - /*.work =*/ NULL, - /*.nodes =*/ { NULL }, - /*.grads =*/ { NULL }, - /*.leafs =*/ { NULL }, - /*.perf_runs =*/ 0, - /*.perf_cycles =*/ 0, - /*.perf_time_us =*/ 0, - }; - - ggml_build_forward_impl(&result, tensor, false); - - return result; -} - -struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cgraph * gf, bool keep) { - struct ggml_cgraph result = *gf; - - assert(gf->n_nodes > 0); - - // if we are keeping the gradient graph, we have to detach the gradient nodes from the original graph - if (keep) { - for (int i = 0; i < gf->n_nodes; i++) { - struct ggml_tensor * node = gf->nodes[i]; - - if (node->grad) { - node->grad = ggml_dup_tensor(ctx, node); - gf->grads[i] = node->grad; - } - } - } - - for (int i = gf->n_nodes - 1; i >= 0; i--) { - struct ggml_tensor * node = gf->nodes[i]; - - // because we detached the grad nodes from the original graph, we can afford inplace operations - if (node->grad) { - ggml_compute_backward(ctx, 
node, keep); - } - } - - for (int i = gf->n_nodes - 1; i >= 0; i--) { - struct ggml_tensor * node = gf->nodes[i]; - - if (node->is_param) { - GGML_PRINT_DEBUG("%s: found root node %p\n", __func__, (void *) node); - ggml_build_forward_impl(&result, node->grad, true); - } - } - - return result; -} - -// -// thread data -// -// synchronization is done via busy loops -// I tried using spin locks, but not sure how to use them correctly - the things I tried were slower than busy loops -// - -#ifdef __APPLE__ - -//#include -// -//typedef os_unfair_lock ggml_lock_t; -// -//#define ggml_lock_init(x) UNUSED(x) -//#define ggml_lock_destroy(x) UNUSED(x) -//#define ggml_lock_lock os_unfair_lock_lock -//#define ggml_lock_unlock os_unfair_lock_unlock -// -//#define GGML_LOCK_INITIALIZER OS_UNFAIR_LOCK_INIT - -typedef int ggml_lock_t; - -#define ggml_lock_init(x) UNUSED(x) -#define ggml_lock_destroy(x) UNUSED(x) -#define ggml_lock_lock(x) UNUSED(x) -#define ggml_lock_unlock(x) UNUSED(x) - -#define GGML_LOCK_INITIALIZER 0 - -typedef pthread_t ggml_thread_t; - -#define ggml_thread_create pthread_create -#define ggml_thread_join pthread_join - -#else - -//typedef pthread_spinlock_t ggml_lock_t; - -//#define ggml_lock_init(x) pthread_spin_init(x, PTHREAD_PROCESS_PRIVATE) -//#define ggml_lock_destroy pthread_spin_destroy -//#define ggml_lock_lock pthread_spin_lock -//#define ggml_lock_unlock pthread_spin_unlock - -typedef int ggml_lock_t; - -#define ggml_lock_init(x) UNUSED(x) -#define ggml_lock_destroy(x) UNUSED(x) -#define ggml_lock_lock(x) UNUSED(x) -#define ggml_lock_unlock(x) UNUSED(x) - -#define GGML_LOCK_INITIALIZER 0 - -typedef pthread_t ggml_thread_t; - -#define ggml_thread_create pthread_create -#define ggml_thread_join pthread_join - -#endif - -struct ggml_compute_state_shared { - ggml_lock_t spin; - - int n_threads; - - // synchronization primitives - atomic_int n_ready; - atomic_bool has_work; - atomic_bool stop; // stop all threads -}; - -struct ggml_compute_state { - ggml_thread_t thrd; - - struct ggml_compute_params params; - struct ggml_tensor * node; - - struct ggml_compute_state_shared * shared; -}; - -static thread_ret_t ggml_graph_compute_thread(void * data) { - struct ggml_compute_state * state = (struct ggml_compute_state *) data; - - const int n_threads = state->shared->n_threads; - - while (true) { - if (atomic_fetch_add(&state->shared->n_ready, 1) == n_threads - 1) { - atomic_store(&state->shared->has_work, false); - } else { - while (atomic_load(&state->shared->has_work)) { - if (atomic_load(&state->shared->stop)) { - return 0; - } - ggml_lock_lock (&state->shared->spin); - ggml_lock_unlock(&state->shared->spin); - } - } - - atomic_fetch_sub(&state->shared->n_ready, 1); - - // wait for work - while (!atomic_load(&state->shared->has_work)) { - if (atomic_load(&state->shared->stop)) { - return 0; - } - ggml_lock_lock (&state->shared->spin); - ggml_lock_unlock(&state->shared->spin); - } - - // check if we should stop - if (atomic_load(&state->shared->stop)) { - break; - } - - if (state->node) { - if (state->params.ith < state->params.nth) { - ggml_compute_forward(&state->params, state->node); - } - - state->node = NULL; - } else { - break; - } - } - - return 0; -} - -void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) { - if (cgraph->n_threads <= 0) { - cgraph->n_threads = 8; - } - - const int n_threads = cgraph->n_threads; - - struct ggml_compute_state_shared state_shared = { - /*.spin =*/ GGML_LOCK_INITIALIZER, - /*.n_threads =*/ n_threads, - /*.n_ready =*/ 
0, - /*.has_work =*/ false, - /*.stop =*/ false, - }; - struct ggml_compute_state * workers = n_threads > 1 ? alloca(sizeof(struct ggml_compute_state)*(n_threads - 1)) : NULL; - - // create thread pool - if (n_threads > 1) { - ggml_lock_init(&state_shared.spin); - - atomic_store(&state_shared.has_work, true); - - for (int j = 0; j < n_threads - 1; j++) { - workers[j] = (struct ggml_compute_state) { - .thrd = 0, - .params = { - .type = GGML_TASK_COMPUTE, - .ith = j + 1, - .nth = n_threads, - .wsize = cgraph->work ? ggml_nbytes(cgraph->work) : 0, - .wdata = cgraph->work ? cgraph->work->data : NULL, - }, - .node = NULL, - .shared = &state_shared, - }; - - int rc = ggml_thread_create(&workers[j].thrd, NULL, ggml_graph_compute_thread, &workers[j]); - assert(rc == 0); - UNUSED(rc); - } - } - - // initialize tasks + work buffer - { - size_t work_size = 0; - - // thread scheduling for the different operations - for (int i = 0; i < cgraph->n_nodes; i++) { - struct ggml_tensor * node = cgraph->nodes[i]; - - switch (node->op) { - case GGML_OP_DUP: - { - node->n_tasks = 1; - } break; - case GGML_OP_ADD: - { - node->n_tasks = n_threads; - } break; - case GGML_OP_SUB: - case GGML_OP_MUL: - case GGML_OP_DIV: - case GGML_OP_SQR: - case GGML_OP_SQRT: - case GGML_OP_SUM: - case GGML_OP_MEAN: - case GGML_OP_REPEAT: - case GGML_OP_ABS: - case GGML_OP_SGN: - case GGML_OP_NEG: - case GGML_OP_STEP: - case GGML_OP_RELU: - { - node->n_tasks = 1; - } break; - case GGML_OP_GELU: - { - node->n_tasks = n_threads; - } break; - case GGML_OP_NORM: - { - node->n_tasks = n_threads; - } break; - case GGML_OP_MUL_MAT: - { - node->n_tasks = n_threads; - - // TODO: use different scheduling for different matrix sizes - //const int nr0 = ggml_nrows(node->src0); - //const int nr1 = ggml_nrows(node->src1); - - //node->n_tasks = MIN(n_threads, MAX(1, nr0/128)); - //printf("nr0 = %8d, nr1 = %8d, nr0*nr1 = %8d, n_tasks = %d\n", nr0, nr1, nr0*nr1, node->n_tasks); - - size_t cur = 0; - - // TODO: better way to determine if the matrix is transposed - if (node->src0->nb[1] < node->src0->nb[0]) { - cur = ggml_nbytes(node)*node->n_tasks; // TODO: this can become (n_tasks-1) - } else { - if (node->src0->type == GGML_TYPE_F16 && - node->src1->type == GGML_TYPE_F32) { -#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) - if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) { - node->n_tasks = 1; // TODO: this actually is doing nothing - // the threads are still spinning - cur = sizeof(float)*(node->src0->ne[0]*node->src0->ne[1]); - //printf("src0: ne0 = %d, ne1 = %d, ne = %d\n", node->src0->ne[0], node->src0->ne[1], node->src0->ne[0]*node->src0->ne[1]); - //printf("src1: ne0 = %d, ne1 = %d, ne = %d\n", node->src1->ne[0], node->src1->ne[1], node->src1->ne[0]*node->src1->ne[1]); - //printf("cur = %zu\n", cur); - } else { - cur = sizeof(ggml_fp16_t)*ggml_nelements(node->src1); - } -#else - cur = sizeof(ggml_fp16_t)*ggml_nelements(node->src1); -#endif - } else if (node->src0->type == GGML_TYPE_F32 && - node->src1->type == GGML_TYPE_F32) { - cur = 0; - } else { - GGML_ASSERT(false); - } - } - - work_size = MAX(work_size, cur); - } break; - case GGML_OP_SCALE: - { - node->n_tasks = n_threads; - } break; - case GGML_OP_CPY: - case GGML_OP_RESHAPE: - case GGML_OP_VIEW: - case GGML_OP_PERMUTE: - case GGML_OP_TRANSPOSE: - case GGML_OP_GET_ROWS: - case GGML_OP_DIAG_MASK_INF: - { - node->n_tasks = 1; - } break; - case GGML_OP_SOFT_MAX: - { - node->n_tasks = n_threads; - } break; - case GGML_OP_ROPE: - { - node->n_tasks = 
1; - } break; - case GGML_OP_CONV_1D_1S: - case GGML_OP_CONV_1D_2S: - { - node->n_tasks = n_threads; - - GGML_ASSERT(node->src0->ne[3] == 1); - GGML_ASSERT(node->src1->ne[2] == 1); - GGML_ASSERT(node->src1->ne[3] == 1); - - size_t cur = 0; - const int nk = node->src0->ne[0]; - - if (node->src0->type == GGML_TYPE_F16 && - node->src1->type == GGML_TYPE_F32) { - cur = sizeof(ggml_fp16_t)*( - nk*ggml_up32(node->src0->ne[1])*node->src0->ne[2] + - ( 2*(nk/2) + node->src1->ne[0])*node->src1->ne[1] - ); - } else if (node->src0->type == GGML_TYPE_F32 && - node->src1->type == GGML_TYPE_F32) { - cur = sizeof(float)*( - nk*ggml_up32(node->src0->ne[1])*node->src0->ne[2] + - ( 2*(nk/2) + node->src1->ne[0])*node->src1->ne[1] - ); - } else { - GGML_ASSERT(false); - } - - work_size = MAX(work_size, cur); - } break; - case GGML_OP_FLASH_ATTN: - { - node->n_tasks = n_threads; - - size_t cur = 0; - - const int ne11 = ggml_up(node->src1->ne[1], GGML_SOFT_MAX_UNROLL); - - if (node->src1->type == GGML_TYPE_F32) { - cur = sizeof(float)*ne11*node->n_tasks; // TODO: this can become (n_tasks-1) - cur += sizeof(float)*ne11*node->n_tasks; // this is overestimated by x2 - } - - if (node->src1->type == GGML_TYPE_F16) { - cur = sizeof(float)*ne11*node->n_tasks; // TODO: this can become (n_tasks-1) - cur += sizeof(float)*ne11*node->n_tasks; // this is overestimated by x2 - } - - work_size = MAX(work_size, cur); - } break; - case GGML_OP_FLASH_FF: - { - node->n_tasks = n_threads; - - size_t cur = 0; - - if (node->src1->type == GGML_TYPE_F32) { - cur = sizeof(float)*node->src1->ne[1]*node->n_tasks; // TODO: this can become (n_tasks-1) - cur += sizeof(float)*node->src1->ne[1]*node->n_tasks; // this is overestimated by x2 - } - - if (node->src1->type == GGML_TYPE_F16) { - cur = sizeof(float)*node->src1->ne[1]*node->n_tasks; // TODO: this can become (n_tasks-1) - cur += sizeof(float)*node->src1->ne[1]*node->n_tasks; // this is overestimated by x2 - } - - work_size = MAX(work_size, cur); - } break; - case GGML_OP_NONE: - { - node->n_tasks = 1; - } break; - case GGML_OP_COUNT: - { - assert(false); - } break; - } - } - - if (cgraph->work != NULL && work_size > cgraph->work_size) { - assert(false); // TODO: better handling - } - - if (work_size > 0 && cgraph->work == NULL) { - cgraph->work_size = work_size + CACHE_LINE_SIZE*(n_threads - 1); - - GGML_PRINT_DEBUG("%s: allocating work buffer for graph (%zu bytes)\n", __func__, cgraph->work_size); - cgraph->work = ggml_new_tensor_1d(ctx, GGML_TYPE_I8, cgraph->work_size); - } - } - - const int64_t perf_start_cycles = ggml_perf_cycles(); - const int64_t perf_start_time_us = ggml_perf_time_us(); - - for (int i = 0; i < cgraph->n_nodes; i++) { - GGML_PRINT_DEBUG_5("%s: %d/%d\n", __func__, i, cgraph->n_nodes); - - struct ggml_tensor * node = cgraph->nodes[i]; - - // TODO: this could be used to avoid unnecessary computations, but it needs to be improved - //if (node->grad == NULL && node->perf_runs > 0) { - // continue; - //} - - const int64_t perf_node_start_cycles = ggml_perf_cycles(); - const int64_t perf_node_start_time_us = ggml_perf_time_us(); - - // INIT - struct ggml_compute_params params = { - /*.type =*/ GGML_TASK_INIT, - /*.ith =*/ 0, - /*.nth =*/ node->n_tasks, - /*.wsize =*/ cgraph->work ? ggml_nbytes(cgraph->work) : 0, - /*.wdata =*/ cgraph->work ? 
cgraph->work->data : NULL, - }; - - ggml_compute_forward(¶ms, node); - - // COMPUTE - if (node->n_tasks > 1) { - if (atomic_fetch_add(&state_shared.n_ready, 1) == n_threads - 1) { - atomic_store(&state_shared.has_work, false); - } - - while (atomic_load(&state_shared.has_work)) { - ggml_lock_lock (&state_shared.spin); - ggml_lock_unlock(&state_shared.spin); - } - - // launch thread pool - for (int j = 0; j < n_threads - 1; j++) { - workers[j].params = (struct ggml_compute_params) { - .type = GGML_TASK_COMPUTE, - .ith = j + 1, - .nth = node->n_tasks, - .wsize = cgraph->work ? ggml_nbytes(cgraph->work) : 0, - .wdata = cgraph->work ? cgraph->work->data : NULL, - }; - workers[j].node = node; - } - - atomic_fetch_sub(&state_shared.n_ready, 1); - - while (atomic_load(&state_shared.n_ready) > 0) { - ggml_lock_lock (&state_shared.spin); - ggml_lock_unlock(&state_shared.spin); - } - - atomic_store(&state_shared.has_work, true); - } - - params.type = GGML_TASK_COMPUTE; - ggml_compute_forward(¶ms, node); - - // wait for thread pool - if (node->n_tasks > 1) { - if (atomic_fetch_add(&state_shared.n_ready, 1) == n_threads - 1) { - atomic_store(&state_shared.has_work, false); - } - - while (atomic_load(&state_shared.has_work)) { - ggml_lock_lock (&state_shared.spin); - ggml_lock_unlock(&state_shared.spin); - } - - atomic_fetch_sub(&state_shared.n_ready, 1); - - while (atomic_load(&state_shared.n_ready) != 0) { - ggml_lock_lock (&state_shared.spin); - ggml_lock_unlock(&state_shared.spin); - } - } - - // FINALIZE - if (node->n_tasks > 1) { - if (atomic_fetch_add(&state_shared.n_ready, 1) == n_threads - 1) { - atomic_store(&state_shared.has_work, false); - } - - while (atomic_load(&state_shared.has_work)) { - ggml_lock_lock (&state_shared.spin); - ggml_lock_unlock(&state_shared.spin); - } - - // launch thread pool - for (int j = 0; j < n_threads - 1; j++) { - workers[j].params = (struct ggml_compute_params) { - .type = GGML_TASK_FINALIZE, - .ith = j + 1, - .nth = node->n_tasks, - .wsize = cgraph->work ? ggml_nbytes(cgraph->work) : 0, - .wdata = cgraph->work ? 
cgraph->work->data : NULL, - }; - workers[j].node = node; - } - - atomic_fetch_sub(&state_shared.n_ready, 1); - - while (atomic_load(&state_shared.n_ready) > 0) { - ggml_lock_lock (&state_shared.spin); - ggml_lock_unlock(&state_shared.spin); - } - - atomic_store(&state_shared.has_work, true); - } - - params.type = GGML_TASK_FINALIZE; - ggml_compute_forward(¶ms, node); - - // wait for thread pool - if (node->n_tasks > 1) { - if (atomic_fetch_add(&state_shared.n_ready, 1) == n_threads - 1) { - atomic_store(&state_shared.has_work, false); - } - - while (atomic_load(&state_shared.has_work)) { - ggml_lock_lock (&state_shared.spin); - ggml_lock_unlock(&state_shared.spin); - } - - atomic_fetch_sub(&state_shared.n_ready, 1); - - while (atomic_load(&state_shared.n_ready) != 0) { - ggml_lock_lock (&state_shared.spin); - ggml_lock_unlock(&state_shared.spin); - } - } - - // performance stats (node) - { - int64_t perf_cycles_cur = ggml_perf_cycles() - perf_node_start_cycles; - int64_t perf_time_us_cur = ggml_perf_time_us() - perf_node_start_time_us; - - node->perf_runs++; - node->perf_cycles += perf_cycles_cur; - node->perf_time_us += perf_time_us_cur; - } - } - - // join thread pool - if (n_threads > 1) { - atomic_store(&state_shared.stop, true); - atomic_store(&state_shared.has_work, true); - - for (int j = 0; j < n_threads - 1; j++) { - int rc = ggml_thread_join(workers[j].thrd, NULL); - assert(rc == 0); - UNUSED(rc); - } - - ggml_lock_destroy(&state_shared.spin); - } - - // performance stats (graph) - { - int64_t perf_cycles_cur = ggml_perf_cycles() - perf_start_cycles; - int64_t perf_time_us_cur = ggml_perf_time_us() - perf_start_time_us; - - cgraph->perf_runs++; - cgraph->perf_cycles += perf_cycles_cur; - cgraph->perf_time_us += perf_time_us_cur; - - GGML_PRINT_DEBUG("%s: perf (%d) - cpu = %.3f / %.3f ms, wall = %.3f / %.3f ms\n", - __func__, cgraph->perf_runs, - (double) perf_cycles_cur / (double) ggml_cycles_per_ms(), - (double) cgraph->perf_cycles / (double) ggml_cycles_per_ms() / (double) cgraph->perf_runs, - (double) perf_time_us_cur / 1000.0, - (double) cgraph->perf_time_us / 1000.0 / cgraph->perf_runs); - } -} - -void ggml_graph_reset(struct ggml_cgraph * cgraph) { - for (int i = 0; i < cgraph->n_nodes; i++) { - struct ggml_tensor * grad = cgraph->grads[i]; - - if (grad) { - ggml_set_zero(grad); - } - } -} - -void ggml_graph_print(const struct ggml_cgraph * cgraph) { - int64_t perf_total_per_op_us[GGML_OP_COUNT] = {0}; - - GGML_PRINT("=== GRAPH ===\n"); - - GGML_PRINT_DEBUG("n_threads = %d\n", cgraph->n_threads); - GGML_PRINT_DEBUG("total work size = %zu bytes\n",cgraph->work_size); - - GGML_PRINT("n_nodes = %d\n", cgraph->n_nodes); - for (int i = 0; i < cgraph->n_nodes; i++) { - struct ggml_tensor * node = cgraph->nodes[i]; - - perf_total_per_op_us[node->op] += node->perf_time_us; - - GGML_PRINT(" - %3d: [ %6d, %6d, %6d] %16s %s (%3d) cpu = %7.3f / %7.3f ms, wall = %7.3f / %7.3f ms\n", - i, - node->ne[0], node->ne[1], node->ne[2], - GGML_OP_LABEL[node->op], node->is_param ? "x" : node->grad ? 
"g" : " ", node->perf_runs, - (double) node->perf_cycles / (double) ggml_cycles_per_ms(), - (double) node->perf_cycles / (double) ggml_cycles_per_ms() / (double) node->perf_runs, - (double) node->perf_time_us / 1000.0, - (double) node->perf_time_us / 1000.0 / node->perf_runs); - } - - GGML_PRINT("n_leafs = %d\n", cgraph->n_leafs); - for (int i = 0; i < cgraph->n_leafs; i++) { - struct ggml_tensor * node = cgraph->leafs[i]; - - GGML_PRINT(" - %3d: [ %6d, %6d] %8s\n", - i, - node->ne[0], node->ne[1], - GGML_OP_LABEL[node->op]); - } - - for (int i = 0; i < GGML_OP_COUNT; i++) { - GGML_PRINT("perf_total_per_op_us[%16s] = %7.3f ms\n", GGML_OP_LABEL[i], (double) perf_total_per_op_us[i] / 1000.0); - } - - GGML_PRINT("========================================\n"); -} - -// check if node is part of the graph -static bool ggml_graph_find(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node) { - if (cgraph == NULL) { - return true; - } - - for (int i = 0; i < cgraph->n_nodes; i++) { - if (cgraph->nodes[i] == node) { - return true; - } - } - - return false; -} - -static struct ggml_tensor * ggml_graph_get_parent(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node) { - for (int i = 0; i < cgraph->n_nodes; i++) { - struct ggml_tensor * parent = cgraph->nodes[i]; - - if (parent->grad == node) { - return parent; - } - } - - return NULL; -} - -void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * gf, const char * filename) { - char color[16]; - - FILE * fp = fopen(filename, "w"); - assert(fp); - - fprintf(fp, "digraph G {\n"); - fprintf(fp, " newrank = true;\n"); - fprintf(fp, " rankdir = LR;\n"); - - for (int i = 0; i < gb->n_nodes; i++) { - struct ggml_tensor * node = gb->nodes[i]; - - if (ggml_graph_get_parent(gb, node) != NULL) { - continue; - } - - if (node->is_param) { - snprintf(color, sizeof(color), "yellow"); - } else if (node->grad) { - if (ggml_graph_find(gf, node)) { - snprintf(color, sizeof(color), "green"); - } else { - snprintf(color, sizeof(color), "lightblue"); - } - } else { - snprintf(color, sizeof(color), "white"); - } - - fprintf(fp, " \"%p\" [ \ -style = filled; fillcolor = %s; shape = record; \ -label=\"%d [%d, %d] | %s", - (void *) node, color, - i, node->ne[0], node->ne[1], - GGML_OP_SYMBOL[node->op]); - - if (node->grad) { - fprintf(fp, " | %s\"; ]\n", GGML_OP_SYMBOL[node->grad->op]); - } else { - fprintf(fp, "\"; ]\n"); - } - } - - for (int i = 0; i < gb->n_leafs; i++) { - struct ggml_tensor * node = gb->leafs[i]; - - snprintf(color, sizeof(color), "pink"); - - if (ggml_nelements(node) == 1) { - fprintf(fp, " \"%p\" [ \ -style = filled; fillcolor = %s; shape = record; \ -label=\"%.1e\"; ]\n", - (void *) node, color, ggml_get_f32_1d(node, 0)); - } else { - fprintf(fp, " \"%p\" [ \ -style = filled; fillcolor = %s; shape = record; \ -label=\"CONST %d [%d, %d]\"; ]\n", - (void *) node, color, - i, node->ne[0], node->ne[1]); - } - } - - for (int i = 0; i < gb->n_nodes; i++) { - struct ggml_tensor * node = gb->nodes[i]; - - struct ggml_tensor * parent = ggml_graph_get_parent(gb, node); - - if (node->src0) { - struct ggml_tensor * parent0 = ggml_graph_get_parent(gb, node->src0); - - fprintf(fp, " \"%p\":%s -> \"%p\":%s [ arrowhead = %s; style = %s; label = \"x\"; ]\n", - parent0 ? (void *) parent0 : (void *) node->src0, - parent0 ? "g" : "x", - parent ? (void *) parent : (void *) node, - parent ? "g" : "x", - parent ? "empty" : "vee", - parent ? 
"dashed" : "solid"); - } - - if (node->src1) { - struct ggml_tensor * parent1 = ggml_graph_get_parent(gb, node->src1); - - fprintf(fp, " \"%p\":%s -> \"%p\":%s [ arrowhead = %s; style = %s; label = \"y\"; ]\n", - parent1 ? (void *) parent1 : (void *) node->src1, - parent1 ? "g" : "x", - parent ? (void *) parent : (void *) node, - parent ? "g" : "x", - parent ? "empty" : "vee", - parent ? "dashed" : "solid"); - } - } - - for (int i = 0; i < gb->n_leafs; i++) { - struct ggml_tensor * node = gb->leafs[i]; - - if (node->src0) { - fprintf(fp, " \"%p\":%s -> \"%p\":%s [ label = \"x\"; ]\n", - (void *) node->src0, "x", - (void *) node, "x"); - } - - if (node->src1) { - fprintf(fp, " \"%p\":%s -> \"%p\":%s [ label = \"y\"; ]\n", - (void *) node->src1, "x", - (void *) node, "x"); - } - } - - fprintf(fp, "}\n"); - - fclose(fp); - - GGML_PRINT("%s: dot -Tpng %s -o %s.png && open %s.png\n", __func__, filename, filename, filename); -} - -//////////////////////////////////////////////////////////////////////////////// - -static void ggml_opt_set_params(int np, struct ggml_tensor * const ps[], const float * x) { - int i = 0; - for (int p = 0; p < np; ++p) { - const int ne = ggml_nelements(ps[p]) ; - // TODO: add function to set tensor from array - for (int j = 0; j < ne; ++j) { - ggml_set_f32_1d(ps[p], j, x[i++]); - } - } -} - -static void ggml_opt_get_params(int np, struct ggml_tensor * const ps[], float * x) { - int i = 0; - for (int p = 0; p < np; ++p) { - const int ne = ggml_nelements(ps[p]) ; - // TODO: add function to get all elements at once - for (int j = 0; j < ne; ++j) { - x[i++] = ggml_get_f32_1d(ps[p], j); - } - } -} - -static void ggml_opt_get_grad(int np, struct ggml_tensor * const ps[], float * g) { - int i = 0; - for (int p = 0; p < np; ++p) { - const int ne = ggml_nelements(ps[p]) ; - // TODO: add function to get all elements at once - for (int j = 0; j < ne; ++j) { - g[i++] = ggml_get_f32_1d(ps[p]->grad, j); - } - } -} - -// -// ADAM -// -// ref: https://arxiv.org/pdf/1412.6980.pdf -// - -static enum ggml_opt_result ggml_opt_adam( - struct ggml_context * ctx, - struct ggml_opt_params params, - struct ggml_tensor * f, - struct ggml_cgraph * gf, - struct ggml_cgraph * gb) { - assert(ggml_is_scalar(f)); - - gf->n_threads = params.n_threads; - gb->n_threads = params.n_threads; - - // these will store the parameters we want to optimize - struct ggml_tensor * ps[GGML_MAX_PARAMS]; - - int np = 0; - int nx = 0; - for (int i = 0; i < gf->n_nodes; ++i) { - if (gf->nodes[i]->is_param) { - GGML_PRINT_DEBUG("found param %d: grad->op = %d\n", np, gf->nodes[i]->grad->op); - - assert(np < GGML_MAX_PARAMS); - - ps[np++] = gf->nodes[i]; - nx += ggml_nelements(gf->nodes[i]); - } - } - - // constants - const float alpha = params.adam.alpha; - const float beta1 = params.adam.beta1; - const float beta2 = params.adam.beta2; - const float eps = params.adam.eps; - - float * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // view of the parameters - float * g1 = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // gradient - float * g2 = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // gradient squared - float * m = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // first moment - float * v = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // second moment - float * mh = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // first moment hat - float * vh = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // second moment hat - - float * pf = params.past > 0 ? 
ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past)->data : NULL; // past function values - - // initialize - ggml_vec_set_f32(nx, m, 0.0f); - ggml_vec_set_f32(nx, v, 0.0f); - - // update view - ggml_opt_get_params(np, ps, x); - - // compute the function value - ggml_graph_reset (gf); - ggml_set_f32 (f->grad, 1.0f); - ggml_graph_compute(ctx, gb); - - float fx_prev = ggml_get_f32_1d(f, 0); - if (pf) { - pf[0] = fx_prev; - } - - int n_no_improvement = 0; - float fx_best = fx_prev; - - // run the optimizer - for (int t = 0; t < params.adam.n_iter; ++t) { - GGML_PRINT_DEBUG ("=== iter %d ===\n", t); - - GGML_PRINT_DEBUG ("f = %10.6f\n", ggml_get_f32_1d(f, 0)); - GGML_PRINT_DEBUG_5("df/dx0 = %10.6f\n", ggml_get_f32_1d(ps[0]->grad, 0)); - GGML_PRINT_DEBUG_5("df/dx1 = %10.6f\n", ggml_get_f32_1d(ps[1]->grad, 0)); - - for (int i = 0; i < np; ++i) { - GGML_PRINT_DEBUG("param %d: %10.6f, g = %10.6f\n", i, - ggml_get_f32_1d(ps[i], 0), ggml_get_f32_1d(ps[i]->grad, 0)); - } - - const int64_t t_start_wall = ggml_time_us(); - const int64_t t_start_cpu = ggml_cycles(); - UNUSED(t_start_wall); - UNUSED(t_start_cpu); - - { - // update the gradient - ggml_opt_get_grad(np, ps, g1); - - // m_t = beta1*m_t-1 + (1 - beta1)*g_t - ggml_vec_scale_f32(nx, m, beta1); - ggml_vec_mad_f32 (nx, m, g1, 1.0f - beta1); - - // g2 = g1^2 - ggml_vec_sqr_f32 (nx, g2, g1); - - // v_t = beta2*v_t-1 + (1 - beta2)*g_t^2 - ggml_vec_scale_f32(nx, v, beta2); - ggml_vec_mad_f32 (nx, v, g2, 1.0f - beta2); - - // m^hat = m_t / (1 - beta1^t) - // v^hat = v_t / (1 - beta2^t) - // x_t = x_t-1 - alpha*m^hat/(sqrt(v^hat) + eps) - ggml_vec_cpy_f32 (nx, mh, m); - ggml_vec_cpy_f32 (nx, vh, v); - - ggml_vec_scale_f32(nx, mh, alpha/(1.0f - powf(beta1, t + 1))); - ggml_vec_scale_f32(nx, vh, 1.0f/(1.0f - powf(beta2, t + 1))); - - ggml_vec_sqrt_f32 (nx, vh, vh); - ggml_vec_acc1_f32 (nx, vh, eps); - - ggml_vec_div_f32 (nx, mh, mh, vh); - ggml_vec_sub_f32 (nx, x, x, mh); - - // update the parameters - ggml_opt_set_params(np, ps, x); - } - - ggml_graph_reset (gf); - ggml_set_f32 (f->grad, 1.0f); - ggml_graph_compute(ctx, gb); - - const float fx = ggml_get_f32_1d(f, 0); - - // check convergence - if (fabsf(fx - fx_prev)/fx < params.adam.eps_f) { - GGML_PRINT_DEBUG("converged\n"); - - return GGML_OPT_OK; - } - - // delta-based convergence test - if (pf != NULL) { - // need at least params.past iterations to start checking for convergence - if (params.past <= t) { - const float rate = (pf[t%params.past] - fx)/fx; - - if (fabs(rate) < params.delta) { - return GGML_OPT_OK; - } - } - - pf[t%params.past] = fx; - } - - // check for improvement - if (params.max_no_improvement > 0) { - if (fx_best > fx) { - fx_best = fx; - n_no_improvement = 0; - } else { - ++n_no_improvement; - - if (n_no_improvement >= params.max_no_improvement) { - return GGML_OPT_OK; - } - } - } - - fx_prev = fx; - - { - const int64_t t_end_cpu = ggml_cycles(); - GGML_PRINT_DEBUG("time iter: %5.3f s\n", ((float)(t_end_cpu - t_start_cpu))/CLOCKS_PER_SEC); - UNUSED(t_end_cpu); - - const int64_t t_end_wall = ggml_time_us(); - GGML_PRINT_DEBUG("wall time iter: %5.3f s\n", (t_end_wall - t_start_wall)/1e6); - UNUSED(t_end_wall); - } - } - - return GGML_OPT_DID_NOT_CONVERGE; -} - -// -// L-BFGS -// -// the L-BFGS implementation below is based on the following implementation: -// -// https://github.com/chokkan/liblbfgs -// - -struct ggml_lbfgs_iteration_data { - float alpha; - float ys; - float * s; - float * y; -}; - -static enum ggml_opt_result linesearch_backtracking( - struct 
ggml_context * ctx, - const struct ggml_opt_params * params, - int nx, - float * x, - float * fx, - float * g, - float * d, - float * step, - const float * xp, - struct ggml_tensor * f, - struct ggml_cgraph * gf, - struct ggml_cgraph * gb, - const int np, - struct ggml_tensor * ps[]) { - int count = 0; - - float width = 0.0f; - float dg = 0.0f; - float finit = 0.0f; - float dginit = 0.0f; - float dgtest = 0.0f; - - const float dec = 0.5f; - const float inc = 2.1f; - - if (*step <= 0.) { - return GGML_LINESEARCH_INVALID_PARAMETERS; - } - - // compute the initial gradient in the search direction - ggml_vec_dot_f32(nx, &dginit, g, d); - - // make sure that d points to a descent direction - if (0 < dginit) { - return GGML_LINESEARCH_FAIL; - } - - // initialize local variables - finit = *fx; - dgtest = params->lbfgs.ftol*dginit; - - while (true) { - ggml_vec_cpy_f32(nx, x, xp); - ggml_vec_mad_f32(nx, x, d, *step); - - // evaluate the function and gradient values - { - ggml_opt_set_params(np, ps, x); - - ggml_graph_reset (gf); - ggml_set_f32 (f->grad, 1.0f); - ggml_graph_compute(ctx, gb); - - ggml_opt_get_grad(np, ps, g); - - *fx = ggml_get_f32_1d(f, 0); - } - - ++count; - - if (*fx > finit + (*step)*dgtest) { - width = dec; - } else { - // Armijo condition is satisfied - if (params->lbfgs.linesearch == GGML_LINESEARCH_BACKTRACKING_ARMIJO) { - return count; - } - - ggml_vec_dot_f32(nx, &dg, g, d); - - // check the Wolfe condition - if (dg < params->lbfgs.wolfe * dginit) { - width = inc; - } else { - if(params->lbfgs.linesearch == GGML_LINESEARCH_BACKTRACKING_WOLFE) { - // regular Wolfe conditions - return count; - } - - if(dg > -params->lbfgs.wolfe*dginit) { - width = dec; - } else { - // strong Wolfe condition (GGML_LINESEARCH_BACKTRACKING_STRONG_WOLFE) - return count; - } - return count; - } - } - - if (*step < params->lbfgs.min_step) { - return GGML_LINESEARCH_MINIMUM_STEP; - } - if (*step > params->lbfgs.max_step) { - return GGML_LINESEARCH_MAXIMUM_STEP; - } - if (params->lbfgs.max_linesearch <= count) { - return GGML_LINESEARCH_MAXIMUM_ITERATIONS; - } - - (*step) *= width; - } - - return GGML_LINESEARCH_FAIL; -} - -static enum ggml_opt_result ggml_opt_lbfgs( - struct ggml_context * ctx, - struct ggml_opt_params params, - struct ggml_tensor * f, - struct ggml_cgraph * gf, - struct ggml_cgraph * gb) { - if (params.lbfgs.linesearch == GGML_LINESEARCH_BACKTRACKING_WOLFE || - params.lbfgs.linesearch == GGML_LINESEARCH_BACKTRACKING_STRONG_WOLFE) { - if (params.lbfgs.wolfe <= params.lbfgs.ftol || 1. 
<= params.lbfgs.wolfe) { - return GGML_OPT_INVALID_WOLFE; - } - } - - gf->n_threads = params.n_threads; - gb->n_threads = params.n_threads; - - const int m = params.lbfgs.m; - - // these will store the parameters we want to optimize - struct ggml_tensor * ps[GGML_MAX_PARAMS]; - - int np = 0; - int nx = 0; - for (int i = 0; i < gf->n_nodes; ++i) { - if (gf->nodes[i]->is_param) { - GGML_PRINT_DEBUG("found param %d: grad->op = %d\n", np, gf->nodes[i]->grad->op); - - assert(np < GGML_MAX_PARAMS); - - ps[np++] = gf->nodes[i]; - nx += ggml_nelements(gf->nodes[i]); - } - } - - float * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // current parameters - float * xp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // previous parameters - float * g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // current gradient - float * gp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // previous gradient - float * d = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // search direction - - float * pf = params.past > 0 ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past)->data : NULL; // past function values - - float fx = 0.0f; // cost function value - float xnorm = 0.0f; // ||x|| - float gnorm = 0.0f; // ||g|| - float step = 0.0f; - - // initialize x from the graph nodes - ggml_opt_get_params(np, ps, x); - - // the L-BFGS memory - struct ggml_lbfgs_iteration_data * lm = alloca(sizeof(struct ggml_lbfgs_iteration_data)*m); - - for (int i = 0; i < m; ++i) { - lm[i].alpha = 0.0f; - lm[i].ys = 0.0f; - lm[i].s = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; - lm[i].y = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; - } - - // evaluate the function value and its gradient - { - ggml_opt_set_params(np, ps, x); - - ggml_graph_reset (gf); - ggml_set_f32 (f->grad, 1.0f); - ggml_graph_compute(ctx, gb); - - ggml_opt_get_grad(np, ps, g); - - fx = ggml_get_f32_1d(f, 0); - } - - if (pf) { - pf[0] = fx; - } - - float fx_best = fx; - - // search direction = -gradient - ggml_vec_neg_f32(nx, d, g); - - // ||x||, ||g|| - ggml_vec_norm_f32(nx, &xnorm, x); - ggml_vec_norm_f32(nx, &gnorm, g); - - if (xnorm < 1.0f) { - xnorm = 1.0f; - } - - // already optimized - if (gnorm/xnorm <= params.lbfgs.eps) { - return GGML_OPT_OK; - } - - // initial step - ggml_vec_norm_inv_f32(nx, &step, d); - - int j = 0; - int k = 1; - int ls = 0; - int end = 0; - int bound = 0; - int n_no_improvement = 0; - - float ys = 0.0f; - float yy = 0.0f; - float beta = 0.0f; - - while (true) { - // store the current position and gradient vectors - ggml_vec_cpy_f32(nx, xp, x); - ggml_vec_cpy_f32(nx, gp, g); - - ls = linesearch_backtracking(ctx, ¶ms, nx, x, &fx, g, d, &step, xp, f, gf, gb, np, ps); - - if (ls < 0) { - // linesearch failed - go back to the previous point and return - ggml_vec_cpy_f32(nx, x, xp); - ggml_vec_cpy_f32(nx, g, gp); - - return ls; - } - - ggml_vec_norm_f32(nx, &xnorm, x); - ggml_vec_norm_f32(nx, &gnorm, g); - - GGML_PRINT_DEBUG("f = %10.6f\n", ggml_get_f32_1d(f, 0)); - - if (xnorm < 1.0) { - xnorm = 1.0; - } - if (gnorm/xnorm <= params.lbfgs.eps) { - // converged - return GGML_OPT_OK; - } - - // delta-based convergence test - if (pf != NULL) { - // need at least params.past iterations to start checking for convergence - if (params.past <= k) { - const float rate = (pf[k%params.past] - fx)/fx; - - if (fabs(rate) < params.delta) { - return GGML_OPT_OK; - } - } - - pf[k%params.past] = fx; - } - - // check for improvement - if (params.max_no_improvement > 0) { - if (fx < fx_best) { - fx_best = fx; - 
n_no_improvement = 0; - } else { - n_no_improvement++; - - if (n_no_improvement >= params.max_no_improvement) { - return GGML_OPT_OK; - } - } - } - - if (params.lbfgs.n_iter != 0 && params.lbfgs.n_iter < k + 1) { - // reached the maximum number of iterations - return GGML_OPT_DID_NOT_CONVERGE; - } - - // update vectors s and y: - // s_{k+1} = x_{k+1} - x_{k} = \step * d_{k}. - // y_{k+1} = g_{k+1} - g_{k}. - // - ggml_vec_sub_f32(nx, lm[end].s, x, xp); - ggml_vec_sub_f32(nx, lm[end].y, g, gp); - - // compute scalars ys and yy: - // ys = y^t \cdot s -> 1 / \rho. - // yy = y^t \cdot y. - // - ggml_vec_dot_f32(nx, &ys, lm[end].y, lm[end].s); - ggml_vec_dot_f32(nx, &yy, lm[end].y, lm[end].y); - - lm[end].ys = ys; - - // find new search direction - // ref: https://en.wikipedia.org/wiki/Limited-memory_BFGS - - bound = (m <= k) ? m : k; - k++; - end = (end + 1)%m; - - // initialize search direction with -g - ggml_vec_neg_f32(nx, d, g); - - j = end; - for (int i = 0; i < bound; ++i) { - j = (j + m - 1) % m; - // \alpha_{j} = \rho_{j} s^{t}_{j} \cdot q_{k+1} - ggml_vec_dot_f32(nx, &lm[j].alpha, lm[j].s, d); - lm[j].alpha /= lm[j].ys; - // q_{i} = q_{i+1} - \alpha_{i} y_{i} - ggml_vec_mad_f32(nx, d, lm[j].y, -lm[j].alpha); - } - - ggml_vec_scale_f32(nx, d, ys/yy); - - for (int i = 0; i < bound; ++i) { - // \beta_{j} = \rho_{j} y^t_{j} \cdot \gamma_{i} - ggml_vec_dot_f32(nx, &beta, lm[j].y, d); - beta /= lm[j].ys; - // \gamma_{i+1} = \gamma_{i} + (\alpha_{j} - \beta_{j}) s_{j} - ggml_vec_mad_f32(nx, d, lm[j].s, lm[j].alpha - beta); - j = (j + 1)%m; - } - - step = 1.0; - } - - return GGML_OPT_DID_NOT_CONVERGE; -} - -struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type) { - struct ggml_opt_params result; - - switch (type) { - case GGML_OPT_ADAM: - { - result = (struct ggml_opt_params) { - .type = GGML_OPT_ADAM, - .n_threads = 1, - .past = 0, - .delta = 1e-5f, - - .max_no_improvement = 100, - - .print_forward_graph = true, - .print_backward_graph = true, - - .adam = { - .n_iter = 10000, - .alpha = 0.001f, - .beta1 = 0.9f, - .beta2 = 0.999f, - .eps = 1e-8f, - .eps_f = 1e-5f, - .eps_g = 1e-3f, - }, - }; - } break; - case GGML_OPT_LBFGS: - { - result = (struct ggml_opt_params) { - .type = GGML_OPT_LBFGS, - .n_threads = 1, - .past = 0, - .delta = 1e-5f, - - .max_no_improvement = 0, - - .print_forward_graph = true, - .print_backward_graph = true, - - .lbfgs = { - .m = 6, - .n_iter = 100, - .max_linesearch = 20, - - .eps = 1e-5f, - .ftol = 1e-4f, - .wolfe = 0.9f, - .min_step = 1e-20f, - .max_step = 1e+20f, - - .linesearch = GGML_LINESEARCH_DEFAULT, - }, - }; - } break; - } - - return result; -} - -enum ggml_opt_result ggml_opt( - struct ggml_context * ctx, - struct ggml_opt_params params, - struct ggml_tensor * f) { - bool free_ctx = false; - if (ctx == NULL) { - struct ggml_init_params params_ctx = { - .mem_size = 16*1024*1024, - .mem_buffer = NULL, - }; - - ctx = ggml_init(params_ctx); - if (ctx == NULL) { - return GGML_OPT_NO_CONTEXT; - } - - free_ctx = true; - } - - enum ggml_opt_result result = GGML_OPT_OK; - - // build forward + backward compute graphs - struct ggml_cgraph gf = ggml_build_forward (f); - struct ggml_cgraph gb = ggml_build_backward(ctx, &gf, false); - - switch (params.type) { - case GGML_OPT_ADAM: - { - result = ggml_opt_adam(ctx, params, f, &gf, &gb); - } break; - case GGML_OPT_LBFGS: - { - result = ggml_opt_lbfgs(ctx, params, f, &gf, &gb); - } break; - } - - if (params.print_forward_graph) { - ggml_graph_print (&gf); - ggml_graph_dump_dot(&gf, NULL, "opt-forward.dot"); 
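// note: ggml_graph_dump_dot() itself prints the graphviz command needed to render the dump,
// e.g. dot -Tpng opt-forward.dot -o opt-forward.png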
- } - - if (params.print_backward_graph) { - ggml_graph_print (&gb); - ggml_graph_dump_dot(&gb, &gf, "opt-backward.dot"); - } - - if (free_ctx) { - ggml_free(ctx); - } - - return result; -} - -//////////////////////////////////////////////////////////////////////////////// - -int ggml_cpu_has_avx(void) { -#if defined(__AVX__) - return 1; -#else - return 0; -#endif -} - -int ggml_cpu_has_avx2(void) { -#if defined(__AVX2__) - return 1; -#else - return 0; -#endif -} - -int ggml_cpu_has_avx512(void) { -#if defined(__AVX512F__) - return 1; -#else - return 0; -#endif -} - -int ggml_cpu_has_fma(void) { -#if defined(__FMA__) - return 1; -#else - return 0; -#endif -} - -int ggml_cpu_has_neon(void) { -#if defined(__ARM_NEON) - return 1; -#else - return 0; -#endif -} - -int ggml_cpu_has_arm_fma(void) { -#if defined(__ARM_FEATURE_FMA) - return 1; -#else - return 0; -#endif -} - -int ggml_cpu_has_f16c(void) { -#if defined(__F16C__) - return 1; -#else - return 0; -#endif -} - -int ggml_cpu_has_fp16_va(void) { -#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) - return 1; -#else - return 0; -#endif -} - -int ggml_cpu_has_wasm_simd(void) { -#if defined(__wasm_simd128__) - return 1; -#else - return 0; -#endif -} - -int ggml_cpu_has_blas(void) { -#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) - return 1; -#else - return 0; -#endif -} - -int ggml_cpu_has_sse3(void) { -#if defined(__SSE3__) - return 1; -#else - return 0; -#endif -} - -int ggml_cpu_has_vsx(void) { -#if defined(__POWER9_VECTOR__) - return 1; -#else - return 0; -#endif -} - -//////////////////////////////////////////////////////////////////////////////// diff --git a/bindings/ruby/ext/ggml.h b/bindings/ruby/ext/ggml.h deleted file mode 100644 index 18f317bec04..00000000000 --- a/bindings/ruby/ext/ggml.h +++ /dev/null @@ -1,748 +0,0 @@ -#pragma once - -// -// GGML Tensor Library -// -// This documentation is still a work in progress. -// If you wish some specific topics to be covered, feel free to drop a comment: -// -// https://github.com/ggerganov/whisper.cpp/issues/40 -// -// ## Overview -// -// This library implements: -// -// - a set of tensor operations -// - automatic differentiation -// - basic optimization algorithms -// -// The aim of this library is to provide a minimalistic approach for various machine learning tasks. This includes, -// but is not limited to, the following: -// -// - linear regression -// - support vector machines -// - neural networks -// -// The library allows the user to define a certain function using the available tensor operations. This function -// definition is represented internally via a computation graph. Each tensor operation in the function definition -// corresponds to a node in the graph. Having the computation graph defined, the user can choose to compute the -// function's value and/or its gradient with respect to the input variables. Optionally, the function can be optimized -// using one of the available optimization algorithms. 
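//
// A minimal sketch of the gradient/optimization side, assuming the ctx, x, a, b and f tensors
// set up in the example just below and using only functions declared later in this header;
// it mirrors the way ggml_opt() drives the forward and backward graphs:
//
// {
//     struct ggml_cgraph gf = ggml_build_forward (f);
//     struct ggml_cgraph gb = ggml_build_backward(ctx, &gf, false);
//
//     ggml_set_f32(x, 2.0f);
//     ggml_set_f32(a, 3.0f);
//     ggml_set_f32(b, 4.0f);
//
//     ggml_graph_reset  (&gf);
//     ggml_set_f32      (f->grad, 1.0f);  // seed df/df = 1
//     ggml_graph_compute(ctx, &gb);       // forward + backward pass
//
//     printf("df/dx = %f\n", ggml_get_f32_1d(x->grad, 0)); // expect 2*a*x = 12.0
//
//     // alternatively, let the built-in optimizer minimize f over the parameter x:
//     // ggml_opt(ctx, ggml_opt_default_params(GGML_OPT_ADAM), f);
// }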
-// -// For example, here we define the function: f(x) = a*x^2 + b -// -// { -// struct ggml_init_params params = { -// .mem_size = 16*1024*1024, -// .mem_buffer = NULL, -// }; -// -// // memory allocation happens here -// struct ggml_context * ctx = ggml_init(params); -// -// struct ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1); -// -// ggml_set_param(ctx, x); // x is an input variable -// -// struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1); -// struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1); -// struct ggml_tensor * x2 = ggml_mul(ctx, x, x); -// struct ggml_tensor * f = ggml_add(ctx, ggml_mul(ctx, a, x2), b); -// -// ... -// } -// -// Notice that the function definition above does not involve any actual computation. The computation is performed only -// when the user explicitly requests it. For example, to compute the function's value at x = 2.0: -// -// { -// ... -// -// struct ggml_cgraph gf = ggml_build_forward(f); -// -// // set the input variable and parameter values -// ggml_set_f32(x, 2.0f); -// ggml_set_f32(a, 3.0f); -// ggml_set_f32(b, 4.0f); -// -// ggml_graph_compute(ctx0, &gf); -// -// printf("f = %f\n", ggml_get_f32_1d(f, 0)); -// -// ... -// } -// -// The actual computation is performed in the ggml_graph_compute() function. -// -// The ggml_new_tensor_...() functions create new tensors. They are allocated in the memory buffer provided to the -// ggml_init() function. You have to be careful not to exceed the memory buffer size. Therefore, you have to know -// in advance how much memory you need for your computation. Alternatively, you can allocate a large enough memory -// and after defining the computation graph, call the ggml_used_mem() function to find out how much memory was -// actually needed. -// -// The ggml_set_param() function marks a tensor as an input variable. This is used by the automatic -// differentiation and optimization algorithms. -// -// The described approach allows to define the function graph once and then compute its forward or backward graphs -// multiple times. All computations will use the same memory buffer allocated in the ggml_init() function. This way -// the user can avoid the memory allocation overhead at runtime. -// -// The library supports multi-dimensional tensors - up to 4 dimensions. The FP16 and FP32 data types are first class -// citizens, but in theory the library can be extended to support FP8 and integer data types. -// -// Each tensor operation produces a new tensor. Initially the library was envisioned to support only the use of unary -// and binary operations. Most of the available operations fall into one of these two categories. With time, it became -// clear that the library needs to support more complex operations. The way to support these operations is not clear -// yet, but a few examples are demonstrated in the following operations: -// -// - ggml_permute() -// - ggml_conv_1d_1s() -// - ggml_conv_1d_2s() -// -// For each tensor operator, the library implements a forward and backward computation function. The forward function -// computes the output tensor value given the input tensor values. The backward function computes the adjoint of the -// input tensors given the adjoint of the output tensor. For a detailed explanation of what this means, take a -// calculus class, or watch the following video: -// -// What is Automatic Differentiation? 
-// https://www.youtube.com/watch?v=wG_nF1awSSY -// -// -// ## Tensor data (struct ggml_tensor) -// -// The tensors are stored in memory via the ggml_tensor struct. The structure provides information about the size of -// the tensor, the data type, and the memory buffer where the tensor data is stored. Additionally, it contains -// pointers to the "source" tensors - i.e. the tensors that were used to compute the current tensor. For example: -// -// { -// struct ggml_tensor * c = ggml_add(ctx, a, b); -// -// assert(c->src[0] == a); -// assert(c->src[1] == b); -// } -// -// The multi-dimensional tensors are stored in row-major order. The ggml_tensor struct contains fields for the -// number of elements in each dimension ("ne") as well as the number of bytes ("nb", a.k.a. stride). This allows -// to store tensors that are not contiguous in memory, which is useful for operations such as transposition and -// permutation. All tensor operations have to take the stride into account and not assume that the tensor is -// contiguous in memory. -// -// The data of the tensor is accessed via the "data" pointer. For example: -// -// { -// struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 2, 3); -// -// // a[1, 2] = 1.0f; -// *(float *) ((char *) a->data + 2*a->nb[1] + 1*a->nb[0]) = 1.0f; -// -// // a[2, 0] = 2.0f; -// *(float *) ((char *) a->data + 0*a->nb[1] + 2*a->nb[0]) = 2.0f; -// -// ... -// } -// -// Alternatively, there are helper functions, such as ggml_get_f32_1d() and ggml_set_f32_1d() that can be used. -// -// ## The matrix multiplication operator (ggml_mul_mat) -// -// TODO -// -// -// ## Multi-threading -// -// TODO -// -// -// ## Overview of ggml.c -// -// TODO -// -// -// ## SIMD optimizations -// -// TODO -// -// -// ## Debugging ggml -// -// TODO -// -// - -#ifdef __cplusplus -extern "C" { -#endif - -#include -#include -#include - -#define GGML_MAX_DIMS 4 -#define GGML_MAX_NODES 4096 -#define GGML_MAX_PARAMS 16 -#define GGML_MAX_CONTEXTS 64 -#define GGML_MAX_OPT 4 - -#ifdef __ARM_NEON -// we use the built-in 16-bit float type -typedef __fp16 ggml_fp16_t; -#else -typedef uint16_t ggml_fp16_t; -#endif - -// convert FP16 <-> FP32 -float ggml_fp16_to_fp32(ggml_fp16_t x); -ggml_fp16_t ggml_fp32_to_fp16(float x); - -struct ggml_object; -struct ggml_context; - -enum ggml_type { - GGML_TYPE_I8, - GGML_TYPE_I16, - GGML_TYPE_I32, - GGML_TYPE_F16, - GGML_TYPE_F32, - GGML_TYPE_COUNT, -}; - -// available tensor operations: -enum ggml_op { - GGML_OP_NONE = 0, - - GGML_OP_DUP, - GGML_OP_ADD, - GGML_OP_SUB, - GGML_OP_MUL, - GGML_OP_DIV, - GGML_OP_SQR, - GGML_OP_SQRT, - GGML_OP_SUM, - GGML_OP_MEAN, - GGML_OP_REPEAT, - GGML_OP_ABS, - GGML_OP_SGN, - GGML_OP_NEG, - GGML_OP_STEP, - GGML_OP_RELU, - GGML_OP_GELU, - GGML_OP_NORM, // normalize - - GGML_OP_MUL_MAT, - - GGML_OP_SCALE, - GGML_OP_CPY, - GGML_OP_RESHAPE, - GGML_OP_VIEW, - GGML_OP_PERMUTE, - GGML_OP_TRANSPOSE, - GGML_OP_GET_ROWS, - GGML_OP_DIAG_MASK_INF, - GGML_OP_SOFT_MAX, - GGML_OP_ROPE, - GGML_OP_CONV_1D_1S, - GGML_OP_CONV_1D_2S, - - GGML_OP_FLASH_ATTN, - GGML_OP_FLASH_FF, - - GGML_OP_COUNT, -}; - -// n-dimensional tensor -struct ggml_tensor { - enum ggml_type type; - - int n_dims; - int ne[GGML_MAX_DIMS]; // number of elements - size_t nb[GGML_MAX_DIMS]; // stride in bytes: - // nb[0] = sizeof(type) - // nb[1] = nb[0] * ne[0] + padding - // nb[i] = nb[i-1] * ne[i-1] - - // compute data - enum ggml_op op; - - bool is_param; - - struct ggml_tensor * grad; - struct ggml_tensor * src0; - struct ggml_tensor * src1; - struct ggml_tensor 
* opt[GGML_MAX_OPT]; - - // thread scheduling - int n_tasks; - - // performance - int perf_runs; - int64_t perf_cycles; - int64_t perf_time_us; - - void * data; - char padding[8]; -}; - -// computation graph -struct ggml_cgraph { - int n_nodes; - int n_leafs; - int n_threads; - - size_t work_size; - struct ggml_tensor * work; - - struct ggml_tensor * nodes[GGML_MAX_NODES]; - struct ggml_tensor * grads[GGML_MAX_NODES]; - struct ggml_tensor * leafs[GGML_MAX_NODES]; - - // performance - int perf_runs; - int64_t perf_cycles; - int64_t perf_time_us; -}; - -// scratch buffer -struct ggml_scratch { - size_t offs; - size_t size; - void * data; -}; - -struct ggml_init_params { - // memory pool - size_t mem_size; // bytes - void * mem_buffer; // if NULL, memory will be allocated internally -}; - -void ggml_time_init(void); // call this once at the beginning of the program -int64_t ggml_time_ms(void); -int64_t ggml_time_us(void); -int64_t ggml_cycles(void); -int64_t ggml_cycles_per_ms(void); - -void ggml_print_object (const struct ggml_object * obj); -void ggml_print_objects(const struct ggml_context * ctx); - -int ggml_nelements(const struct ggml_tensor * tensor); -size_t ggml_nbytes (const struct ggml_tensor * tensor); - -size_t ggml_type_size (enum ggml_type type); -size_t ggml_element_size(const struct ggml_tensor * tensor); - -struct ggml_context * ggml_init(struct ggml_init_params params); -void ggml_free(struct ggml_context * ctx); - -size_t ggml_used_mem(const struct ggml_context * ctx); - -size_t ggml_set_scratch(struct ggml_context * ctx, struct ggml_scratch scratch); - -struct ggml_tensor * ggml_new_tensor( - struct ggml_context * ctx, - enum ggml_type type, - int n_dims, - const int *ne); - -struct ggml_tensor * ggml_new_tensor_1d( - struct ggml_context * ctx, - enum ggml_type type, - int ne0); - -struct ggml_tensor * ggml_new_tensor_2d( - struct ggml_context * ctx, - enum ggml_type type, - int ne0, - int ne1); - -struct ggml_tensor * ggml_new_tensor_3d( - struct ggml_context * ctx, - enum ggml_type type, - int ne0, - int ne1, - int ne2); - -struct ggml_tensor * ggml_new_tensor_4d( - struct ggml_context * ctx, - enum ggml_type type, - int ne0, - int ne1, - int ne2, - int ne3); - -struct ggml_tensor * ggml_new_i32(struct ggml_context * ctx, int32_t value); -struct ggml_tensor * ggml_new_f32(struct ggml_context * ctx, float value); - -struct ggml_tensor * ggml_dup_tensor (struct ggml_context * ctx, const struct ggml_tensor * src); -struct ggml_tensor * ggml_view_tensor(struct ggml_context * ctx, const struct ggml_tensor * src); - -struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor); -struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value); -struct ggml_tensor * ggml_set_f32 (struct ggml_tensor * tensor, float value); - -int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i); -void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value); - -float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i); -void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value); - - void * ggml_get_data (const struct ggml_tensor * tensor); -float * ggml_get_data_f32(const struct ggml_tensor * tensor); - -// -// operations on tensors with backpropagation -// - -struct ggml_tensor * ggml_dup( - struct ggml_context * ctx, - struct ggml_tensor * a); - -struct ggml_tensor * ggml_add( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b); - -struct ggml_tensor * ggml_sub( - struct ggml_context * ctx, - 
struct ggml_tensor * a, - struct ggml_tensor * b); - -struct ggml_tensor * ggml_mul( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b); - -struct ggml_tensor * ggml_div( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b); - -struct ggml_tensor * ggml_sqr( - struct ggml_context * ctx, - struct ggml_tensor * a); - -struct ggml_tensor * ggml_sqrt( - struct ggml_context * ctx, - struct ggml_tensor * a); - -// return scalar -// TODO: compute sum along rows -struct ggml_tensor * ggml_sum( - struct ggml_context * ctx, - struct ggml_tensor * a); - -// mean along rows -struct ggml_tensor * ggml_mean( - struct ggml_context * ctx, - struct ggml_tensor * a); - -// if a is the same shape as b, and a is not parameter, return a -// otherwise, return a new tensor: repeat(a) to fit in b -struct ggml_tensor * ggml_repeat( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b); - -struct ggml_tensor * ggml_abs( - struct ggml_context * ctx, - struct ggml_tensor * a); - -struct ggml_tensor * ggml_sgn( - struct ggml_context * ctx, - struct ggml_tensor * a); - -struct ggml_tensor * ggml_neg( - struct ggml_context * ctx, - struct ggml_tensor * a); - -struct ggml_tensor * ggml_step( - struct ggml_context * ctx, - struct ggml_tensor * a); - -struct ggml_tensor * ggml_relu( - struct ggml_context * ctx, - struct ggml_tensor * a); - -// TODO: double-check this computation is correct -struct ggml_tensor * ggml_gelu( - struct ggml_context * ctx, - struct ggml_tensor * a); - -// normalize along rows -// TODO: eps is hardcoded to 1e-5 for now -struct ggml_tensor * ggml_norm( - struct ggml_context * ctx, - struct ggml_tensor * a); - -// A: m rows, n columns -// B: p rows, n columns (i.e. we transpose it internally) -// result is m columns, p rows -struct ggml_tensor * ggml_mul_mat( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b); - -// -// operations on tensors without backpropagation -// - -// in-place, returns view(a) -struct ggml_tensor * ggml_scale( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b); - -// a -> b, return view(b) -struct ggml_tensor * ggml_cpy( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b); - -// return view(a), b specifies the new shape -// TODO: when we start computing gradient, make a copy instead of view -struct ggml_tensor * ggml_reshape( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b); - -// return view(a) -// TODO: when we start computing gradient, make a copy instead of view -struct ggml_tensor * ggml_reshape_2d( - struct ggml_context * ctx, - struct ggml_tensor * a, - int ne0, - int ne1); - -// return view(a) -// TODO: when we start computing gradient, make a copy instead of view -struct ggml_tensor * ggml_reshape_3d( - struct ggml_context * ctx, - struct ggml_tensor * a, - int ne0, - int ne1, - int ne2); - -// offset in bytes -struct ggml_tensor * ggml_view_1d( - struct ggml_context * ctx, - struct ggml_tensor * a, - int ne0, - size_t offset); - -struct ggml_tensor * ggml_view_2d( - struct ggml_context * ctx, - struct ggml_tensor * a, - int ne0, - int ne1, - size_t nb1, // row stride in bytes - size_t offset); - -struct ggml_tensor * ggml_permute( - struct ggml_context * ctx, - struct ggml_tensor * a, - int axis0, - int axis1, - int axis2, - int axis3); - -// alias for ggml_permute(ctx, a, 1, 0, 2, 3) -struct ggml_tensor * ggml_transpose( - struct ggml_context * ctx, - struct 
ggml_tensor * a); - -struct ggml_tensor * ggml_get_rows( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b); - -// set elements above the diagonal to -INF -// in-place, returns view(a) -struct ggml_tensor * ggml_diag_mask_inf( - struct ggml_context * ctx, - struct ggml_tensor * a, - int n_past); - -// in-place, returns view(a) -struct ggml_tensor * ggml_soft_max( - struct ggml_context * ctx, - struct ggml_tensor * a); - -// rotary position embedding -// in-place, returns view(a) -// if mode == 1, skip n_past elements -// TODO: avoid creating a new tensor every time -struct ggml_tensor * ggml_rope( - struct ggml_context * ctx, - struct ggml_tensor * a, - int n_past, - int n_dims, - int mode); - -// padding = 1 -// TODO: we don't support extra parameters for now -// that's why we are hard-coding the stride, padding, and dilation -// not great .. -struct ggml_tensor * ggml_conv_1d_1s( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b); - -struct ggml_tensor * ggml_conv_1d_2s( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b); - -struct ggml_tensor * ggml_flash_attn( - struct ggml_context * ctx, - struct ggml_tensor * q, - struct ggml_tensor * k, - struct ggml_tensor * v, - bool masked); - -struct ggml_tensor * ggml_flash_ff( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b0, - struct ggml_tensor * b1, - struct ggml_tensor * c0, - struct ggml_tensor * c1); - -// -// automatic differentiation -// - -void ggml_set_param( - struct ggml_context * ctx, - struct ggml_tensor * tensor); - -void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor); - -struct ggml_cgraph ggml_build_forward (struct ggml_tensor * tensor); -struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cgraph * gf, bool keep); - -void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph); -void ggml_graph_reset (struct ggml_cgraph * cgraph); - -// print info and performance information for the graph -void ggml_graph_print(const struct ggml_cgraph * cgraph); - -// dump the graph into a file using the dot format -void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * gf, const char * filename); - -// -// optimization -// - -// optimization methods -enum ggml_opt_type { - GGML_OPT_ADAM, - GGML_OPT_LBFGS, -}; - -// linesearch methods -enum ggml_linesearch { - GGML_LINESEARCH_DEFAULT = 1, - - GGML_LINESEARCH_BACKTRACKING_ARMIJO = 0, - GGML_LINESEARCH_BACKTRACKING_WOLFE = 1, - GGML_LINESEARCH_BACKTRACKING_STRONG_WOLFE = 2, -}; - -// optimization return values -enum ggml_opt_result { - GGML_OPT_OK = 0, - GGML_OPT_DID_NOT_CONVERGE, - GGML_OPT_NO_CONTEXT, - GGML_OPT_INVALID_WOLFE, - GGML_OPT_FAIL, - - GGML_LINESEARCH_FAIL = -128, - GGML_LINESEARCH_MINIMUM_STEP, - GGML_LINESEARCH_MAXIMUM_STEP, - GGML_LINESEARCH_MAXIMUM_ITERATIONS, - GGML_LINESEARCH_INVALID_PARAMETERS, -}; - -// optimization parameters -// -// see ggml.c (ggml_opt_default_params) for default values -// -struct ggml_opt_params { - enum ggml_opt_type type; - - int n_threads; - - // delta-based convergence test - // - // if past == 0 - disabled - // if past > 0: - // stop if |f(x) - f(x_past)| < delta * max(1, |f(x)|) - // - int past; - float delta; - - // maximum number of iterations without improvement - // - // if 0 - disabled - // if > 0: - // assume convergence if no cost improvement in this number of iterations - // - int max_no_improvement; - - bool 
print_forward_graph; - bool print_backward_graph; - - // ADAM parameters - struct { - int n_iter; - - float alpha; // learning rate - float beta1; - float beta2; - float eps; // epsilon for numerical stability - float eps_f; // epsilon for convergence test - float eps_g; // epsilon for convergence test - } adam; - - // LBFGS parameters - struct { - int m; // number of corrections to approximate the inv. Hessian - int n_iter; - int max_linesearch; - - float eps; // convergence tolerance - float ftol; // line search tolerance - float wolfe; - float min_step; - float max_step; - - enum ggml_linesearch linesearch; - } lbfgs; -}; - -struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type); - -// optimize the function defined by the tensor f -enum ggml_opt_result ggml_opt( - struct ggml_context * ctx, - struct ggml_opt_params params, - struct ggml_tensor * f); - -// -// system info -// - -int ggml_cpu_has_avx(void); -int ggml_cpu_has_avx2(void); -int ggml_cpu_has_avx512(void); -int ggml_cpu_has_fma(void); -int ggml_cpu_has_neon(void); -int ggml_cpu_has_arm_fma(void); -int ggml_cpu_has_f16c(void); -int ggml_cpu_has_fp16_va(void); -int ggml_cpu_has_wasm_simd(void); -int ggml_cpu_has_blas(void); -int ggml_cpu_has_sse3(void); -int ggml_cpu_has_vsx(void); - -#ifdef __cplusplus -} -#endif diff --git a/bindings/ruby/ext/whisper.cpp b/bindings/ruby/ext/whisper.cpp deleted file mode 100644 index 24e16bd5cf9..00000000000 --- a/bindings/ruby/ext/whisper.cpp +++ /dev/null @@ -1,4814 +0,0 @@ -#define WHISPER_BUILD -#include "whisper.h" - -#include "ggml.h" - -#include -#include -#define _USE_MATH_DEFINES -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#if defined(GGML_BIG_ENDIAN) -#include - -template -static T byteswap(T value) { - return std::byteswap(value); -} - -template<> -float byteswap(float value) { - return std::bit_cast(byteswap(std::bit_cast(value))); -} - -template -static void byteswap_tensor_data(ggml_tensor * tensor) { - T * datum = reinterpret_cast(tensor->data); - for (int i = 0; i < ggml_nelements(tensor); i++) { - datum[i] = byteswap(datum[i]); - } -} - -static void byteswap_tensor(ggml_tensor * tensor) { - switch (tensor->type) { - case GGML_TYPE_I16: { - byteswap_tensor_data(tensor); - break; - } - case GGML_TYPE_F16: { - byteswap_tensor_data(tensor); - break; - } - case GGML_TYPE_I32: { - byteswap_tensor_data(tensor); - break; - } - case GGML_TYPE_F32: { - byteswap_tensor_data(tensor); - break; - } - default: { // GML_TYPE_I8 - break; - } - } -} - -#define BYTESWAP_VALUE(d) d = byteswap(d) -#define BYTESWAP_FILTERS(f) \ - do { \ - for (auto & datum : f.data) { \ - datum = byteswap(datum); \ - } \ - } while (0) -#define BYTESWAP_TENSOR(t) \ - do { \ - byteswap_tensor(tensor); \ - } while (0) -#else -#define BYTESWAP_VALUE(d) do {} while (0) -#define BYTESWAP_FILTERS(f) do {} while (0) -#define BYTESWAP_TENSOR(t) do {} while (0) -#endif - -#define WHISPER_ASSERT(x) \ - do { \ - if (!(x)) { \ - fprintf(stderr, "WHISPER_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \ - abort(); \ - } \ - } while (0) - -// define this to enable verbose trace logging - useful for debugging purposes -//#define WHISPER_DEBUG - -#if defined(WHISPER_DEBUG) -#define WHISPER_PRINT_DEBUG(...) \ - do { \ - fprintf(stderr, __VA_ARGS__); \ - } while (0) -#else -#define WHISPER_PRINT_DEBUG(...) 
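// note: with WHISPER_DEBUG undefined, WHISPER_PRINT_DEBUG(...) expands to nothing,
// so the verbose trace calls are compiled out entirely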
-#endif - -#define WHISPER_USE_FLASH_ATTN -//#define WHISPER_USE_FLASH_FF -#define WHISPER_MAX_DECODERS 16 - -#define WHISPER_USE_SCRATCH -#define WHISPER_MAX_SCRATCH_BUFFERS 16 - -// available whisper models -enum e_model { - MODEL_UNKNOWN, - MODEL_TINY, - MODEL_BASE, - MODEL_SMALL, - MODEL_MEDIUM, - MODEL_LARGE, -}; - -static const std::map> g_lang = { - { "en", { 0, "english", } }, - { "zh", { 1, "chinese", } }, - { "de", { 2, "german", } }, - { "es", { 3, "spanish", } }, - { "ru", { 4, "russian", } }, - { "ko", { 5, "korean", } }, - { "fr", { 6, "french", } }, - { "ja", { 7, "japanese", } }, - { "pt", { 8, "portuguese", } }, - { "tr", { 9, "turkish", } }, - { "pl", { 10, "polish", } }, - { "ca", { 11, "catalan", } }, - { "nl", { 12, "dutch", } }, - { "ar", { 13, "arabic", } }, - { "sv", { 14, "swedish", } }, - { "it", { 15, "italian", } }, - { "id", { 16, "indonesian", } }, - { "hi", { 17, "hindi", } }, - { "fi", { 18, "finnish", } }, - { "vi", { 19, "vietnamese", } }, - { "iw", { 20, "hebrew", } }, - { "uk", { 21, "ukrainian", } }, - { "el", { 22, "greek", } }, - { "ms", { 23, "malay", } }, - { "cs", { 24, "czech", } }, - { "ro", { 25, "romanian", } }, - { "da", { 26, "danish", } }, - { "hu", { 27, "hungarian", } }, - { "ta", { 28, "tamil", } }, - { "no", { 29, "norwegian", } }, - { "th", { 30, "thai", } }, - { "ur", { 31, "urdu", } }, - { "hr", { 32, "croatian", } }, - { "bg", { 33, "bulgarian", } }, - { "lt", { 34, "lithuanian", } }, - { "la", { 35, "latin", } }, - { "mi", { 36, "maori", } }, - { "ml", { 37, "malayalam", } }, - { "cy", { 38, "welsh", } }, - { "sk", { 39, "slovak", } }, - { "te", { 40, "telugu", } }, - { "fa", { 41, "persian", } }, - { "lv", { 42, "latvian", } }, - { "bn", { 43, "bengali", } }, - { "sr", { 44, "serbian", } }, - { "az", { 45, "azerbaijani", } }, - { "sl", { 46, "slovenian", } }, - { "kn", { 47, "kannada", } }, - { "et", { 48, "estonian", } }, - { "mk", { 49, "macedonian", } }, - { "br", { 50, "breton", } }, - { "eu", { 51, "basque", } }, - { "is", { 52, "icelandic", } }, - { "hy", { 53, "armenian", } }, - { "ne", { 54, "nepali", } }, - { "mn", { 55, "mongolian", } }, - { "bs", { 56, "bosnian", } }, - { "kk", { 57, "kazakh", } }, - { "sq", { 58, "albanian", } }, - { "sw", { 59, "swahili", } }, - { "gl", { 60, "galician", } }, - { "mr", { 61, "marathi", } }, - { "pa", { 62, "punjabi", } }, - { "si", { 63, "sinhala", } }, - { "km", { 64, "khmer", } }, - { "sn", { 65, "shona", } }, - { "yo", { 66, "yoruba", } }, - { "so", { 67, "somali", } }, - { "af", { 68, "afrikaans", } }, - { "oc", { 69, "occitan", } }, - { "ka", { 70, "georgian", } }, - { "be", { 71, "belarusian", } }, - { "tg", { 72, "tajik", } }, - { "sd", { 73, "sindhi", } }, - { "gu", { 74, "gujarati", } }, - { "am", { 75, "amharic", } }, - { "yi", { 76, "yiddish", } }, - { "lo", { 77, "lao", } }, - { "uz", { 78, "uzbek", } }, - { "fo", { 79, "faroese", } }, - { "ht", { 80, "haitian creole", } }, - { "ps", { 81, "pashto", } }, - { "tk", { 82, "turkmen", } }, - { "nn", { 83, "nynorsk", } }, - { "mt", { 84, "maltese", } }, - { "sa", { 85, "sanskrit", } }, - { "lb", { 86, "luxembourgish", } }, - { "my", { 87, "myanmar", } }, - { "bo", { 88, "tibetan", } }, - { "tl", { 89, "tagalog", } }, - { "mg", { 90, "malagasy", } }, - { "as", { 91, "assamese", } }, - { "tt", { 92, "tatar", } }, - { "haw", { 93, "hawaiian", } }, - { "ln", { 94, "lingala", } }, - { "ha", { 95, "hausa", } }, - { "ba", { 96, "bashkir", } }, - { "jw", { 97, "javanese", } }, - { "su", { 98, "sundanese", } }, -}; - -static const size_t 
MB = 1024*1024; - -static const std::map MEM_REQ_SCRATCH0 = { - { MODEL_TINY, 12ull*MB }, - { MODEL_BASE, 15ull*MB }, - { MODEL_SMALL, 23ull*MB }, - { MODEL_MEDIUM, 31ull*MB }, - { MODEL_LARGE, 38ull*MB }, -}; - -static const std::map MEM_REQ_SCRATCH1 = { - { MODEL_TINY, 18ull*MB }, - { MODEL_BASE, 24ull*MB }, - { MODEL_SMALL, 36ull*MB }, - { MODEL_MEDIUM, 48ull*MB }, - { MODEL_LARGE, 60ull*MB }, -}; - -static const std::map MEM_REQ_SCRATCH2 = { - { MODEL_TINY, 4ull*MB }, - { MODEL_BASE, 4ull*MB }, - { MODEL_SMALL, 6ull*MB }, - { MODEL_MEDIUM, 7ull*MB }, - { MODEL_LARGE, 9ull*MB }, -}; - -static const std::map MEM_REQ_SCRATCH3 = { - { MODEL_TINY, 4ull*MB }, - { MODEL_BASE, 4ull*MB }, - { MODEL_SMALL, 6ull*MB }, - { MODEL_MEDIUM, 7ull*MB }, - { MODEL_LARGE, 9ull*MB }, -}; - -static const std::map MEM_REQ_MODEL = { - { MODEL_TINY, 74ull*MB }, - { MODEL_BASE, 142ull*MB }, - { MODEL_SMALL, 466ull*MB }, - { MODEL_MEDIUM, 1464ull*MB }, - { MODEL_LARGE, 2952ull*MB }, -}; - -static const std::map MEM_REQ_KV_SELF = { - { MODEL_TINY, 3ull*MB }, - { MODEL_BASE, 6ull*MB }, - { MODEL_SMALL, 16ull*MB }, - { MODEL_MEDIUM, 43ull*MB }, - { MODEL_LARGE, 71ull*MB }, -}; - -static const std::map MEM_REQ_KV_CROSS = { - { MODEL_TINY, 9ull*MB }, - { MODEL_BASE, 18ull*MB }, - { MODEL_SMALL, 53ull*MB }, - { MODEL_MEDIUM, 141ull*MB }, - { MODEL_LARGE, 235ull*MB }, -}; - -static const std::map MEM_REQ_ENCODE = { - { MODEL_TINY, 6ull*MB }, - { MODEL_BASE, 8ull*MB }, - { MODEL_SMALL, 13ull*MB }, - { MODEL_MEDIUM, 22ull*MB }, - { MODEL_LARGE, 33ull*MB }, -}; - -static const std::map MEM_REQ_DECODE = { - { MODEL_TINY, 3ull*MB }, - { MODEL_BASE, 5ull*MB }, - { MODEL_SMALL, 10ull*MB }, - { MODEL_MEDIUM, 18ull*MB }, - { MODEL_LARGE, 27ull*MB }, -}; - -struct whisper_mel { - int n_len; - int n_mel; - - std::vector data; -}; - -struct whisper_filters { - int32_t n_mel; - int32_t n_fft; - - std::vector data; -}; - -struct whisper_vocab { - using id = int32_t; - using token = std::string; - - int n_vocab = 51864; - - std::map token_to_id; - std::map id_to_token; - - id token_eot = 50256; - id token_sot = 50257; - id token_prev = 50360; - id token_solm = 50361; // ?? 
- id token_not = 50362; // no timestamps - id token_beg = 50363; - - // available tasks - static const id token_translate = 50358; - static const id token_transcribe = 50359; - - bool is_multilingual() const { - return n_vocab == 51865; - } -}; - -struct whisper_segment { - int64_t t0; - int64_t t1; - - std::string text; - - std::vector tokens; -}; - -// medium -// hparams: { -// 'n_mels': 80, -// 'n_vocab': 51864, -// 'n_audio_ctx': 1500, -// 'n_audio_state': 1024, -// 'n_audio_head': 16, -// 'n_audio_layer': 24, -// 'n_text_ctx': 448, -// 'n_text_state': 1024, -// 'n_text_head': 16, -// 'n_text_layer': 24 -// } -// -// default hparams (Whisper tiny) -struct whisper_hparams { - int32_t n_vocab = 51864; - int32_t n_audio_ctx = 1500; - int32_t n_audio_state = 384; - int32_t n_audio_head = 6; - int32_t n_audio_layer = 4; - int32_t n_text_ctx = 448; - int32_t n_text_state = 384; - int32_t n_text_head = 6; - int32_t n_text_layer = 4; - int32_t n_mels = 80; - int32_t f16 = 1; -}; - -// audio encoding layer -struct whisper_layer_encoder { - // encoder.blocks.*.attn_ln - struct ggml_tensor * attn_ln_0_w; - struct ggml_tensor * attn_ln_0_b; - - // encoder.blocks.*.attn.out - struct ggml_tensor * attn_ln_1_w; - struct ggml_tensor * attn_ln_1_b; - - // encoder.blocks.*.attn.query - struct ggml_tensor * attn_q_w; - struct ggml_tensor * attn_q_b; - - // encoder.blocks.*.attn.key - struct ggml_tensor * attn_k_w; - - // encoder.blocks.*.attn.value - struct ggml_tensor * attn_v_w; - struct ggml_tensor * attn_v_b; - - // encoder.blocks.*.mlp_ln - struct ggml_tensor * mlp_ln_w; - struct ggml_tensor * mlp_ln_b; - - // encoder.blocks.*.mlp.0 - struct ggml_tensor * mlp_0_w; - struct ggml_tensor * mlp_0_b; - - // encoder.blocks.*.mlp.2 - struct ggml_tensor * mlp_1_w; - struct ggml_tensor * mlp_1_b; -}; - -// token decoding layer -struct whisper_layer_decoder { - // decoder.blocks.*.attn_ln - struct ggml_tensor * attn_ln_0_w; - struct ggml_tensor * attn_ln_0_b; - - // decoder.blocks.*.attn.out - struct ggml_tensor * attn_ln_1_w; - struct ggml_tensor * attn_ln_1_b; - - // decoder.blocks.*.attn.query - struct ggml_tensor * attn_q_w; - struct ggml_tensor * attn_q_b; - - // decoder.blocks.*.attn.key - struct ggml_tensor * attn_k_w; - - // decoder.blocks.*.attn.value - struct ggml_tensor * attn_v_w; - struct ggml_tensor * attn_v_b; - - // decoder.blocks.*.cross_attn_ln - struct ggml_tensor * cross_attn_ln_0_w; - struct ggml_tensor * cross_attn_ln_0_b; - - // decoder.blocks.*.cross_attn.out - struct ggml_tensor * cross_attn_ln_1_w; - struct ggml_tensor * cross_attn_ln_1_b; - - // decoder.blocks.*.cross_attn.query - struct ggml_tensor * cross_attn_q_w; - struct ggml_tensor * cross_attn_q_b; - - // decoder.blocks.*.cross_attn.key - struct ggml_tensor * cross_attn_k_w; - - // decoder.blocks.*.cross_attn.value - struct ggml_tensor * cross_attn_v_w; - struct ggml_tensor * cross_attn_v_b; - - // decoder.blocks.*.mlp_ln - struct ggml_tensor * mlp_ln_w; - struct ggml_tensor * mlp_ln_b; - - // decoder.blocks.*.mlp.0 - struct ggml_tensor * mlp_0_w; - struct ggml_tensor * mlp_0_b; - - // decoder.blocks.*.mlp.2 - struct ggml_tensor * mlp_1_w; - struct ggml_tensor * mlp_1_b; -}; - -struct whisper_kv_cache { - struct ggml_tensor * k; - struct ggml_tensor * v; - - struct ggml_context * ctx; - - std::vector buf; - - int n; // number of tokens currently in the cache -}; - -struct whisper_model { - e_model type = MODEL_UNKNOWN; - - whisper_hparams hparams; - whisper_filters filters; - - // encoder.positional_embedding - struct 
ggml_tensor * e_pe; - - // encoder.conv1 - struct ggml_tensor * e_conv_1_w; - struct ggml_tensor * e_conv_1_b; - - // encoder.conv2 - struct ggml_tensor * e_conv_2_w; - struct ggml_tensor * e_conv_2_b; - - // encoder.ln_post - struct ggml_tensor * e_ln_w; - struct ggml_tensor * e_ln_b; - - // decoder.positional_embedding - struct ggml_tensor * d_pe; - - // decoder.token_embedding - struct ggml_tensor * d_te; - - // decoder.ln - struct ggml_tensor * d_ln_w; - struct ggml_tensor * d_ln_b; - - std::vector layers_encoder; - std::vector layers_decoder; - - // context - struct ggml_context * ctx; - - // the model memory buffer is read-only and can be shared between processors - std::vector * buf; - - // tensors - int n_loaded; - std::map tensors; -}; - -struct whisper_sequence { - std::vector tokens; - - // the accumulated transcription in the current interation (used to truncate the tokens array) - int result_len; - - double sum_logprobs_all; // the sum of the log probabilities of the tokens - double sum_logprobs; // the sum of the log probabilities of the tokens (first result_len tokens) - double avg_logprobs; // the average log probability of the tokens - double entropy; // the entropy of the tokens - double score; // likelihood rank score -}; - -// TAGS: WHISPER_DECODER_INIT -struct whisper_decoder { - // each decoders keeps its own KV-cache - whisper_kv_cache kv_self; - - // the currently generated sequence of tokens - whisper_sequence sequence; - - int seek_delta; // the window shift found so far based on the decoded timestamp tokens - - bool failed; // has the current segment failed to decode? - bool completed; // has the decoder completed the current segment? - bool has_ts; // have we already sampled a non-beg timestamp token for the current segment? 
- - // new token probs, logits and logprobs after the last whisper_decode (1-dimensional array: [n_vocab]) - std::vector probs; - std::vector logits; - std::vector logprobs; - - std::vector tokens_tmp; // used for whisper_decode calls -}; - -struct whisper_context { - int64_t t_load_us = 0; - int64_t t_mel_us = 0; - int64_t t_sample_us = 0; - int64_t t_encode_us = 0; - int64_t t_decode_us = 0; - int64_t t_start_us = 0; - - int32_t n_sample = 0; // number of tokens sampled - int32_t n_encode = 0; // number of encoder calls - int32_t n_decode = 0; // number of decoder calls - int32_t n_fail_p = 0; // number of logprob threshold failures - int32_t n_fail_h = 0; // number of entropy threshold failures - - ggml_type wtype; // weight type (FP32 or FP16) - - whisper_mel mel; - - whisper_model model; - whisper_vocab vocab; - - // cross-attention KV cache for the decoders - // shared between all decoders - whisper_kv_cache kv_cross; - - whisper_decoder decoders[WHISPER_MAX_DECODERS] = {}; - - // memory buffers used by encode / decode contexts - std::vector buf_compute; - std::vector buf_scratch[WHISPER_MAX_SCRATCH_BUFFERS]; - - int buf_last = 0; - size_t buf_max_size[WHISPER_MAX_SCRATCH_BUFFERS] = { 0 }; - - // decode output (2-dimensional array: [n_tokens][n_vocab]) - std::vector logits; - - std::vector result_all; - std::vector prompt_past; - - // work container used to avoid memory allocations - std::vector> logits_id; - - mutable std::mt19937 rng; // used for sampling at t > 0.0 - - int lang_id; - - // [EXPERIMENTAL] token-level timestamps data - int64_t t_beg; - int64_t t_last; - whisper_token tid_last; - std::vector energy; // PCM signal energy - - // [EXPERIMENTAL] speed-up techniques - int32_t exp_n_audio_ctx; // 0 - use default - - void use_buf(struct ggml_context * ctx, int i) { -#if defined(WHISPER_USE_SCRATCH) - size_t last_size = 0; - - if (i == -1) { - last_size = ggml_set_scratch(ctx, { 0, 0, nullptr, }); - } else { - auto & buf = buf_scratch[i]; - last_size = ggml_set_scratch(ctx, { 0, buf.size(), buf.data(), }); - } - - if (buf_last >= 0) { - buf_max_size[buf_last] = std::max(buf_max_size[buf_last], last_size); - } - - buf_last = i; -#else - (void) i; - (void) ctx; -#endif - } - - size_t get_buf_max_mem(int i) const { -#if defined(WHISPER_USE_SCRATCH) - return buf_max_size[i]; -#else - (void) i; - return 0; -#endif - } -}; - -template -static void read_safe(whisper_model_loader * loader, T & dest) { - loader->read(loader->context, &dest, sizeof(T)); - BYTESWAP_VALUE(dest); -} - -static bool kv_cache_init( - const struct whisper_hparams & hparams, - const size_t mem_bytes, - struct whisper_kv_cache & cache, - ggml_type wtype, - int n_ctx) { - cache.buf.resize(mem_bytes); - - struct ggml_init_params params; - params.mem_size = cache.buf.size(); - params.mem_buffer = cache.buf.data(); - - cache.ctx = ggml_init(params); - - if (!cache.ctx) { - fprintf(stderr, "%s: failed to allocate memory for kv cache\n", __func__); - return false; - } - - const int n_text_state = hparams.n_text_state; - const int n_text_layer = hparams.n_text_layer; - - const int n_mem = n_text_layer*n_ctx; - const int n_elements = n_text_state*n_mem; - - cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements); - cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements); - - return true; -} - -static bool kv_cache_reinit(struct whisper_kv_cache & cache) { - WHISPER_ASSERT(cache.ctx); - - const int n_elements = ggml_nelements(cache.k); - WHISPER_ASSERT(n_elements == ggml_nelements(cache.v)); - - const ggml_type 
wtype = cache.k->type; - WHISPER_ASSERT(wtype == cache.v->type); - - WHISPER_ASSERT(cache.buf.size() >= 2*n_elements*ggml_type_size(wtype)); - - struct ggml_init_params params; - params.mem_size = cache.buf.size(); - params.mem_buffer = cache.buf.data(); - - cache.ctx = ggml_init(params); - - if (!cache.ctx) { - fprintf(stderr, "%s: failed to allocate memory for kv cache\n", __func__); - return false; - } - - cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements); - cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements); - - return true; -} - -static void kv_cache_free(struct whisper_kv_cache & cache) { - if (cache.ctx) { - ggml_free(cache.ctx); - cache.ctx = nullptr; - } -} - -// load the model from a ggml file -// -// file format: -// -// - hparams -// - pre-computed mel filters -// - vocab -// - weights -// -// see the convert-pt-to-ggml.py script for details -// -static bool whisper_model_load(struct whisper_model_loader * loader, whisper_context & wctx) { - fprintf(stderr, "%s: loading model\n", __func__); - - const int64_t t_start_us = ggml_time_us(); - - wctx.t_start_us = t_start_us; - - auto & model = wctx.model; - auto & vocab = wctx.vocab; - - // verify magic - { - uint32_t magic; - read_safe(loader, magic); - if (magic != 0x67676d6c) { - fprintf(stderr, "%s: invalid model data (bad magic)\n", __func__); - return false; - } - } - - //load hparams - { - auto & hparams = model.hparams; - - read_safe(loader, hparams.n_vocab); - read_safe(loader, hparams.n_audio_ctx); - read_safe(loader, hparams.n_audio_state); - read_safe(loader, hparams.n_audio_head); - read_safe(loader, hparams.n_audio_layer); - read_safe(loader, hparams.n_text_ctx); - read_safe(loader, hparams.n_text_state); - read_safe(loader, hparams.n_text_head); - read_safe(loader, hparams.n_text_layer); - read_safe(loader, hparams.n_mels); - read_safe(loader, hparams.f16); - - assert(hparams.n_text_state == hparams.n_audio_state); - - if (hparams.n_audio_layer == 4) { - model.type = e_model::MODEL_TINY; - } - - if (hparams.n_audio_layer == 6) { - model.type = e_model::MODEL_BASE; - } - - if (hparams.n_audio_layer == 12) { - model.type = e_model::MODEL_SMALL; - } - - if (hparams.n_audio_layer == 24) { - model.type = e_model::MODEL_MEDIUM; - } - - if (hparams.n_audio_layer == 32) { - model.type = e_model::MODEL_LARGE; - } - - // for the big tensors, we have the option to store the data in 16-bit floats - // in order to save memory and also to speed up the computation - wctx.wtype = model.hparams.f16 ? GGML_TYPE_F16 : GGML_TYPE_F32; - - const size_t scale = model.hparams.f16 ? 
1 : 2; - - fprintf(stderr, "%s: n_vocab = %d\n", __func__, hparams.n_vocab); - fprintf(stderr, "%s: n_audio_ctx = %d\n", __func__, hparams.n_audio_ctx); - fprintf(stderr, "%s: n_audio_state = %d\n", __func__, hparams.n_audio_state); - fprintf(stderr, "%s: n_audio_head = %d\n", __func__, hparams.n_audio_head); - fprintf(stderr, "%s: n_audio_layer = %d\n", __func__, hparams.n_audio_layer); - fprintf(stderr, "%s: n_text_ctx = %d\n", __func__, hparams.n_text_ctx); - fprintf(stderr, "%s: n_text_state = %d\n", __func__, hparams.n_text_state); - fprintf(stderr, "%s: n_text_head = %d\n", __func__, hparams.n_text_head); - fprintf(stderr, "%s: n_text_layer = %d\n", __func__, hparams.n_text_layer); - fprintf(stderr, "%s: n_mels = %d\n", __func__, hparams.n_mels); - fprintf(stderr, "%s: f16 = %d\n", __func__, hparams.f16); - fprintf(stderr, "%s: type = %d\n", __func__, model.type); - - // print memory requirements - { - // this is the total memory required to run the inference - const size_t mem_required = - MEM_REQ_SCRATCH0.at (model.type) + - MEM_REQ_SCRATCH1.at (model.type) + - MEM_REQ_SCRATCH2.at (model.type) + - MEM_REQ_SCRATCH3.at (model.type) + - scale*MEM_REQ_MODEL.at (model.type) + - scale*MEM_REQ_KV_CROSS.at(model.type) + - scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type)); - - // this is the memory required by one decoder - const size_t mem_required_decoder = - scale*MEM_REQ_KV_SELF.at(model.type); - - fprintf(stderr, "%s: mem required = %7.2f MB (+ %7.2f MB per decoder)\n", __func__, - mem_required / 1024.0 / 1024.0, mem_required_decoder / 1024.0 / 1024.0); - } - - // initialize all memory buffers - // always have at least one decoder - - wctx.model.buf = new std::vector(); - wctx.model.buf->resize(scale*MEM_REQ_MODEL.at(model.type)); - - if (!kv_cache_init(model.hparams, scale*MEM_REQ_KV_SELF.at(model.type), wctx.decoders[0].kv_self, wctx.wtype, model.hparams.n_text_ctx)) { - fprintf(stderr, "%s: kv_cache_init() failed for self-attention cache\n", __func__); - return false; - } - - { - const size_t memory_size = ggml_nbytes(wctx.decoders[0].kv_self.k) + ggml_nbytes(wctx.decoders[0].kv_self.v); - fprintf(stderr, "%s: kv self size = %7.2f MB\n", __func__, memory_size/1024.0/1024.0); - } - - if (!kv_cache_init(model.hparams, scale*MEM_REQ_KV_CROSS.at(model.type), wctx.kv_cross, wctx.wtype, model.hparams.n_audio_ctx)) { - fprintf(stderr, "%s: kv_cache_init() failed for cross-attention cache\n", __func__); - return false; - } - - { - const size_t memory_size = ggml_nbytes(wctx.kv_cross.k) + ggml_nbytes(wctx.kv_cross.v); - fprintf(stderr, "%s: kv cross size = %7.2f MB\n", __func__, memory_size/1024.0/1024.0); - } - - wctx.buf_compute.resize(scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type))); - - wctx.buf_scratch[0].resize(MEM_REQ_SCRATCH0.at(model.type)); - wctx.buf_scratch[1].resize(MEM_REQ_SCRATCH1.at(model.type)); - wctx.buf_scratch[2].resize(MEM_REQ_SCRATCH2.at(model.type)); - wctx.buf_scratch[3].resize(MEM_REQ_SCRATCH3.at(model.type)); - } - - // load mel filters - { - auto & filters = wctx.model.filters; - - read_safe(loader, filters.n_mel); - read_safe(loader, filters.n_fft); - - filters.data.resize(filters.n_mel * filters.n_fft); - loader->read(loader->context, filters.data.data(), filters.data.size() * sizeof(float)); - BYTESWAP_FILTERS(filters); - } - - // load vocab - { - int32_t n_vocab = 0; - read_safe(loader, n_vocab); - - //if (n_vocab != model.hparams.n_vocab) { - // fprintf(stderr, "%s: invalid model file '%s' (bad vocab 
size %d != %d)\n", - // __func__, fname.c_str(), n_vocab, model.hparams.n_vocab); - // return false; - //} - - std::string word; - std::vector tmp; - - tmp.reserve(128); - - for (int i = 0; i < n_vocab; i++) { - uint32_t len; - read_safe(loader, len); - - if (len > 0) { - tmp.resize(len); - loader->read(loader->context, &tmp[0], tmp.size()); // read to buffer - word.assign(&tmp[0], tmp.size()); - } else { - // seems like we have an empty-string token in multi-language models (i = 50256) - //fprintf(stderr, "%s: warning: empty-string token in vocab, i = %d\n", __func__, i); - word = ""; - } - - vocab.token_to_id[word] = i; - vocab.id_to_token[i] = word; - - //printf("%s: vocab[%d] = '%s'\n", __func__, i, word.c_str()); - } - - vocab.n_vocab = model.hparams.n_vocab; - if (vocab.is_multilingual()) { - vocab.token_eot++; - vocab.token_sot++; - vocab.token_prev++; - vocab.token_solm++; - vocab.token_not++; - vocab.token_beg++; - } - - if (n_vocab < model.hparams.n_vocab) { - fprintf(stderr, "%s: adding %d extra tokens\n", __func__, model.hparams.n_vocab - n_vocab); - for (int i = n_vocab; i < model.hparams.n_vocab; i++) { - if (i > vocab.token_beg) { - word = "[_TT_" + std::to_string(i - vocab.token_beg) + "]"; - } else if (i == vocab.token_eot) { - word = "[_EOT_]"; - } else if (i == vocab.token_sot) { - word = "[_SOT_]"; - } else if (i == vocab.token_prev) { - word = "[_PREV_]"; - } else if (i == vocab.token_not) { - word = "[_NOT_]"; - } else if (i == vocab.token_beg) { - word = "[_BEG_]"; - } else { - word = "[_extra_token_" + std::to_string(i) + "]"; - } - vocab.token_to_id[word] = i; - vocab.id_to_token[i] = word; - } - } - - wctx.logits.reserve(vocab.n_vocab*model.hparams.n_text_ctx); - - wctx.logits_id.reserve(n_vocab); - - // TAGS: WHISPER_DECODER_INIT - wctx.decoders[0].sequence.tokens.reserve(model.hparams.n_text_ctx); - - wctx.decoders[0].probs.reserve (vocab.n_vocab); - wctx.decoders[0].logits.reserve (vocab.n_vocab); - wctx.decoders[0].logprobs.reserve(vocab.n_vocab); - } - - size_t ctx_size = 0; - - const ggml_type wtype = wctx.wtype; - - { - const auto & hparams = model.hparams; - - const int n_vocab = hparams.n_vocab; - - const int n_audio_ctx = hparams.n_audio_ctx; - const int n_audio_state = hparams.n_audio_state; - const int n_audio_layer = hparams.n_audio_layer; - - const int n_text_ctx = hparams.n_text_ctx; - const int n_text_state = hparams.n_text_state; - const int n_text_layer = hparams.n_text_layer; - - const int n_mels = hparams.n_mels; - - // encoder - { - ctx_size += n_audio_ctx*n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_pe; - - ctx_size += 3*n_mels*n_audio_state*ggml_type_size(wtype); // e_conv_1_w - ctx_size += n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_conv_1_b - - ctx_size += 3*n_audio_state*n_audio_state*ggml_type_size(wtype); // e_conv_2_w - ctx_size += n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_conv_2_b - - ctx_size += n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_ln_w; - ctx_size += n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_ln_b; - } - - // decoder - { - ctx_size += n_text_ctx*n_text_state*ggml_type_size(GGML_TYPE_F32); // d_pe; - - ctx_size += n_vocab*n_text_state*ggml_type_size(wtype); // d_te; - - ctx_size += n_text_state*ggml_type_size(GGML_TYPE_F32); // d_ln_w; - ctx_size += n_text_state*ggml_type_size(GGML_TYPE_F32); // d_ln_b; - } - - // encoder layers - { - ctx_size += n_audio_layer*(n_audio_state*ggml_type_size(GGML_TYPE_F32)); // mlp_ln_w - ctx_size += n_audio_layer*(n_audio_state*ggml_type_size(GGML_TYPE_F32)); // 
mlp_ln_b - - ctx_size += n_audio_layer*(4*n_audio_state*n_audio_state*ggml_type_size(wtype)); // mlp_0_w - ctx_size += n_audio_layer*( 4*n_audio_state*ggml_type_size(GGML_TYPE_F32)); // mlp_0_b - - ctx_size += n_audio_layer*(4*n_audio_state*n_audio_state*ggml_type_size(wtype)); // mlp_1_w - ctx_size += n_audio_layer*( n_audio_state*ggml_type_size(GGML_TYPE_F32)); // mlp_1_b - - ctx_size += n_audio_layer*(n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_0_w - ctx_size += n_audio_layer*(n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_0_b - - ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_size(wtype)); // attn_q_w - ctx_size += n_audio_layer*( n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_q_b - - ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_size(wtype)); // attn_k_w - - ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_size(wtype)); // attn_v_w - ctx_size += n_audio_layer*( n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_v_b - - ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_size(wtype)); // attn_ln_1_w - ctx_size += n_audio_layer*( n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_1_b - } - - // decoder layers - { - ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // mlp_ln_w - ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // mlp_ln_b - - ctx_size += n_text_layer*(4*n_text_state*n_text_state*ggml_type_size(wtype)); // mlp_0_w - ctx_size += n_text_layer*( 4*n_text_state*ggml_type_size(GGML_TYPE_F32)); // mlp_0_b - - ctx_size += n_text_layer*(4*n_text_state*n_text_state*ggml_type_size(wtype)); // mlp_1_w - ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // mlp_1_b - - ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_0_w - ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_0_b - - ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // attn_q_w - ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_q_b - - ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // attn_k_w - - ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // attn_v_w - ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_v_b - - ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // attn_ln_1_w - ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_1_b - // - ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_ln_0_w - ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_ln_0_b - - ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // cross_attn_q_w - ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_q_b - - ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // cross_attn_k_w - - ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // cross_attn_v_w - ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_v_b - - ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // cross_attn_ln_1_w - ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_ln_1_b - } - - ctx_size += (15 + 15*n_audio_layer + 24*n_text_layer)*256; // object overhead - - fprintf(stderr, "%s: 
model ctx = %7.2f MB\n", __func__, ctx_size/(1024.0*1024.0)); - } - - // create the ggml context - { - struct ggml_init_params params; - params.mem_size = wctx.model.buf->size(); - params.mem_buffer = wctx.model.buf->data(); - - model.ctx = ggml_init(params); - if (!model.ctx) { - fprintf(stderr, "%s: ggml_init() failed\n", __func__); - return false; - } - } - - // prepare memory for the weights - { - auto & ctx = model.ctx; - - const auto & hparams = model.hparams; - - const int n_vocab = hparams.n_vocab; - - const int n_audio_ctx = hparams.n_audio_ctx; - const int n_audio_state = hparams.n_audio_state; - const int n_audio_layer = hparams.n_audio_layer; - - const int n_text_ctx = hparams.n_text_ctx; - const int n_text_state = hparams.n_text_state; - const int n_text_layer = hparams.n_text_layer; - - const int n_mels = hparams.n_mels; - - model.layers_encoder.resize(n_audio_layer); - model.layers_decoder.resize(n_text_layer); - - // encoder - { - model.e_pe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_state, n_audio_ctx); - - model.e_conv_1_w = ggml_new_tensor_3d(ctx, wtype, 3, n_mels, n_audio_state); - model.e_conv_1_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state); - - model.e_conv_2_w = ggml_new_tensor_3d(ctx, wtype, 3, n_audio_state, n_audio_state); - model.e_conv_2_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state); - - model.e_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state); - model.e_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state); - - // map by name - model.tensors["encoder.positional_embedding"] = model.e_pe; - - model.tensors["encoder.conv1.weight"] = model.e_conv_1_w; - model.tensors["encoder.conv1.bias"] = model.e_conv_1_b; - - model.tensors["encoder.conv2.weight"] = model.e_conv_2_w; - model.tensors["encoder.conv2.bias"] = model.e_conv_2_b; - - model.tensors["encoder.ln_post.weight"] = model.e_ln_w; - model.tensors["encoder.ln_post.bias"] = model.e_ln_b; - - for (int i = 0; i < n_audio_layer; ++i) { - auto & layer = model.layers_encoder[i]; - - layer.mlp_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state); - layer.mlp_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state); - - layer.mlp_0_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, 4*n_audio_state); - layer.mlp_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_audio_state); - - layer.mlp_1_w = ggml_new_tensor_2d(ctx, wtype, 4*n_audio_state, n_audio_state); - layer.mlp_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state); - - layer.attn_ln_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state); - layer.attn_ln_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state); - - layer.attn_q_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state); - layer.attn_q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state); - - layer.attn_k_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state); - - layer.attn_v_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state); - layer.attn_v_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state); - - layer.attn_ln_1_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state); - layer.attn_ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state); - - // map by name - model.tensors["encoder.blocks." + std::to_string(i) + ".mlp_ln.weight"] = layer.mlp_ln_w; - model.tensors["encoder.blocks." + std::to_string(i) + ".mlp_ln.bias"] = layer.mlp_ln_b; - - model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.0.weight"] = layer.mlp_0_w; - model.tensors["encoder.blocks." 
+ std::to_string(i) + ".mlp.0.bias"] = layer.mlp_0_b; - - model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.2.weight"] = layer.mlp_1_w; - model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.2.bias"] = layer.mlp_1_b; - - model.tensors["encoder.blocks." + std::to_string(i) + ".attn_ln.weight"] = layer.attn_ln_0_w; - model.tensors["encoder.blocks." + std::to_string(i) + ".attn_ln.bias"] = layer.attn_ln_0_b; - - model.tensors["encoder.blocks." + std::to_string(i) + ".attn.query.weight"] = layer.attn_q_w; - model.tensors["encoder.blocks." + std::to_string(i) + ".attn.query.bias"] = layer.attn_q_b; - - model.tensors["encoder.blocks." + std::to_string(i) + ".attn.key.weight"] = layer.attn_k_w; - - model.tensors["encoder.blocks." + std::to_string(i) + ".attn.value.weight"] = layer.attn_v_w; - model.tensors["encoder.blocks." + std::to_string(i) + ".attn.value.bias"] = layer.attn_v_b; - - model.tensors["encoder.blocks." + std::to_string(i) + ".attn.out.weight"] = layer.attn_ln_1_w; - model.tensors["encoder.blocks." + std::to_string(i) + ".attn.out.bias"] = layer.attn_ln_1_b; - } - } - - // decoder - { - model.d_pe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_text_state, n_text_ctx); - - model.d_te = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_vocab); - - model.d_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); - model.d_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); - - // map by name - model.tensors["decoder.positional_embedding"] = model.d_pe; - - model.tensors["decoder.token_embedding.weight"] = model.d_te; - - model.tensors["decoder.ln.weight"] = model.d_ln_w; - model.tensors["decoder.ln.bias"] = model.d_ln_b; - - for (int i = 0; i < n_text_layer; ++i) { - auto & layer = model.layers_decoder[i]; - - layer.mlp_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); - layer.mlp_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); - - layer.mlp_0_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, 4*n_text_state); - layer.mlp_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_text_state); - - layer.mlp_1_w = ggml_new_tensor_2d(ctx, wtype, 4*n_text_state, n_text_state); - layer.mlp_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); - - layer.attn_ln_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); - layer.attn_ln_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); - - layer.attn_q_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state); - layer.attn_q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); - - layer.attn_k_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state); - - layer.attn_v_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state); - layer.attn_v_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); - - layer.attn_ln_1_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state); - layer.attn_ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); - - layer.cross_attn_ln_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); - layer.cross_attn_ln_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); - - layer.cross_attn_q_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state); - layer.cross_attn_q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); - - layer.cross_attn_k_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state); - - layer.cross_attn_v_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state); - layer.cross_attn_v_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); - - layer.cross_attn_ln_1_w = 
ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state); - layer.cross_attn_ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); - - // map by name - model.tensors["decoder.blocks." + std::to_string(i) + ".mlp_ln.weight"] = layer.mlp_ln_w; - model.tensors["decoder.blocks." + std::to_string(i) + ".mlp_ln.bias"] = layer.mlp_ln_b; - - model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.0.weight"] = layer.mlp_0_w; - model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.0.bias"] = layer.mlp_0_b; - - model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.2.weight"] = layer.mlp_1_w; - model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.2.bias"] = layer.mlp_1_b; - - model.tensors["decoder.blocks." + std::to_string(i) + ".attn_ln.weight"] = layer.attn_ln_0_w; - model.tensors["decoder.blocks." + std::to_string(i) + ".attn_ln.bias"] = layer.attn_ln_0_b; - - model.tensors["decoder.blocks." + std::to_string(i) + ".attn.query.weight"] = layer.attn_q_w; - model.tensors["decoder.blocks." + std::to_string(i) + ".attn.query.bias"] = layer.attn_q_b; - - model.tensors["decoder.blocks." + std::to_string(i) + ".attn.key.weight"] = layer.attn_k_w; - - model.tensors["decoder.blocks." + std::to_string(i) + ".attn.value.weight"] = layer.attn_v_w; - model.tensors["decoder.blocks." + std::to_string(i) + ".attn.value.bias"] = layer.attn_v_b; - - model.tensors["decoder.blocks." + std::to_string(i) + ".attn.out.weight"] = layer.attn_ln_1_w; - model.tensors["decoder.blocks." + std::to_string(i) + ".attn.out.bias"] = layer.attn_ln_1_b; - - model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn_ln.weight"] = layer.cross_attn_ln_0_w; - model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn_ln.bias"] = layer.cross_attn_ln_0_b; - - model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.query.weight"] = layer.cross_attn_q_w; - model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.query.bias"] = layer.cross_attn_q_b; - - model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.key.weight"] = layer.cross_attn_k_w; - - model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.value.weight"] = layer.cross_attn_v_w; - model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.value.bias"] = layer.cross_attn_v_b; - - model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.out.weight"] = layer.cross_attn_ln_1_w; - model.tensors["decoder.blocks." 
+ std::to_string(i) + ".cross_attn.out.bias"] = layer.cross_attn_ln_1_b; - } - } - } - - // load weights - { - size_t total_size = 0; - - model.n_loaded = 0; - - while (true) { - int32_t n_dims; - int32_t length; - int32_t ftype; - - read_safe(loader, n_dims); - read_safe(loader, length); - read_safe(loader, ftype); - - if (loader->eof(loader->context)) { - break; - } - - int32_t nelements = 1; - int32_t ne[3] = { 1, 1, 1 }; - for (int i = 0; i < n_dims; ++i) { - read_safe(loader, ne[i]); - nelements *= ne[i]; - } - - std::string name; - std::vector tmp(length); // create a buffer - loader->read(loader->context, &tmp[0], tmp.size()); // read to buffer - name.assign(&tmp[0], tmp.size()); - - if (model.tensors.find(name) == model.tensors.end()) { - fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data()); - return false; - } - - auto tensor = model.tensors[name.data()]; - if (ggml_nelements(tensor) != nelements) { - fprintf(stderr, "%s: tensor '%s' has wrong size in model file\n", __func__, name.data()); - return false; - } - - if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1] || tensor->ne[2] != ne[2]) { - fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d, %d], expected [%d, %d, %d]\n", - __func__, name.data(), tensor->ne[0], tensor->ne[1], tensor->ne[2], ne[0], ne[1], ne[2]); - return false; - } - - const size_t bpe = (ftype == 0) ? sizeof(float) : sizeof(ggml_fp16_t); - - if (nelements*bpe != ggml_nbytes(tensor)) { - fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n", - __func__, name.data(), ggml_nbytes(tensor), nelements*bpe); - return false; - } - - loader->read(loader->context, tensor->data, ggml_nbytes(tensor)); - BYTESWAP_TENSOR(tensor); - - //printf("%48s - [%5d, %5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ne[2], ftype == 0 ? "float" : "f16", ggml_nbytes(tensor)/1024.0/1024.0); - total_size += ggml_nbytes(tensor); - model.n_loaded++; - } - - fprintf(stderr, "%s: model size = %7.2f MB\n", __func__, total_size/1024.0/1024.0); - - if (model.n_loaded == 0) { - fprintf(stderr, "%s: WARN no tensors loaded from model file - assuming empty model for testing\n", __func__); - } else if (model.n_loaded != (int) model.tensors.size()) { - fprintf(stderr, "%s: ERROR not all tensors loaded from model file - expected %zu, got %d\n", __func__, model.tensors.size(), model.n_loaded); - return false; - } - } - - wctx.rng = std::mt19937(0); - - wctx.t_load_us = ggml_time_us() - t_start_us; - - return true; -} - -// evaluate the encoder -// -// given audio recording (more specifically, its log mel spectrogram), runs forward pass of the encoder -// part of the transformer model and returns the encoded features -// -// - model: the model -// - n_threads: number of threads to use -// - mel_offset: offset in the mel spectrogram (i.e. audio offset) -// -static bool whisper_encode( - whisper_context & wctx, - const int mel_offset, - const int n_threads) { - const int64_t t_start_us = ggml_time_us(); - - const auto & model = wctx.model; - const auto & mel_inp = wctx.mel; - const auto & hparams = model.hparams; - - const int n_ctx = wctx.exp_n_audio_ctx > 0 ? 
wctx.exp_n_audio_ctx : hparams.n_audio_ctx; - const int n_state = hparams.n_audio_state; - const int n_head = hparams.n_audio_head; - const int n_layer = hparams.n_audio_layer; - - const int n_mels = hparams.n_mels; - assert(mel_inp.n_mel == n_mels); - - struct ggml_init_params params; - params.mem_size = wctx.buf_compute.size(); - params.mem_buffer = wctx.buf_compute.data(); - - struct ggml_context * ctx0 = ggml_init(params); - - wctx.use_buf(ctx0, 0); - - struct ggml_tensor * mel = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2*n_ctx, n_mels); - assert(mel->type == GGML_TYPE_F32); - { - float * dst = (float *) mel->data; - memset(dst, 0, ggml_nbytes(mel)); - - const int i0 = std::min(mel_offset, mel_inp.n_len); - const int i1 = std::min(mel_offset + 2*n_ctx, mel_inp.n_len); - - for (int j = 0; j < mel_inp.n_mel; ++j) { - for (int i = i0; i < i1; ++i) { - dst[j*2*n_ctx + (i - i0)] = mel_inp.data[j*mel_inp.n_len + i]; - } - } - } - - struct ggml_tensor * cur; - - // convolution + gelu - { - wctx.use_buf(ctx0, 1); - - cur = ggml_conv_1d_1s(ctx0, model.e_conv_1_w, mel); - cur = ggml_add(ctx0, - ggml_repeat(ctx0, - model.e_conv_1_b, - cur), - cur); - - cur = ggml_gelu(ctx0, cur); - - wctx.use_buf(ctx0, 0); - - cur = ggml_conv_1d_2s(ctx0, model.e_conv_2_w, cur); - cur = ggml_add(ctx0, - ggml_repeat(ctx0, - model.e_conv_2_b, - cur), - cur); - - cur = ggml_gelu(ctx0, cur); - } - - wctx.use_buf(ctx0, 3); - - // =================================================================== - // NOTE: experimenting with partial evaluation of the encoder (ignore) - //static int iter = -1; - //const int n_iter = 1500/n_ctx; - - //iter = (iter + 1) % n_iter; - - //if (iter == 0) { - // memset(model.memory_cross_k->data, 0, ggml_nbytes(model.memory_cross_k)); - // memset(model.memory_cross_v->data, 0, ggml_nbytes(model.memory_cross_v)); - //} - - static int iter = 0; - - const size_t e_pe_stride = model.e_pe->ne[0]*ggml_element_size(model.e_pe); - const size_t e_pe_offset = model.e_pe->ne[0]*ggml_element_size(model.e_pe)*n_ctx*iter; - - struct ggml_tensor * e_pe = ggml_view_2d(ctx0, model.e_pe, model.e_pe->ne[0], n_ctx, e_pe_stride, e_pe_offset); - - cur = ggml_add(ctx0, e_pe, ggml_transpose(ctx0, cur)); - - // =================================================================== - - // original: - //cur = ggml_add(ctx0, model.e_pe, ggml_transpose(ctx0, cur)); - - struct ggml_tensor * inpL = cur; - - for (int il = 0; il < n_layer; ++il) { - const auto & layer = model.layers_encoder[il]; - - // norm - { - wctx.use_buf(ctx0, 0); - - cur = ggml_norm(ctx0, inpL); - - // cur = ln_0_w*cur + ln_0_b - cur = ggml_add(ctx0, - ggml_mul(ctx0, - ggml_repeat(ctx0, layer.attn_ln_0_w, cur), - cur), - ggml_repeat(ctx0, layer.attn_ln_0_b, cur)); - } - - // self-attention - { - wctx.use_buf(ctx0, 1); - - struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, - layer.attn_q_w, - cur); - - Qcur = ggml_add(ctx0, - ggml_repeat(ctx0, - layer.attn_q_b, - Qcur), - Qcur); - - //Qcur = ggml_scale(ctx0, Qcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25))); - - // note: no bias for Key - struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, - layer.attn_k_w, - cur); - - //Kcur = ggml_scale(ctx0, Kcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25))); - - struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, - layer.attn_v_w, - cur); - - Vcur = ggml_add(ctx0, - ggml_repeat(ctx0, - layer.attn_v_b, - Vcur), - Vcur); - - // ------ - - wctx.use_buf(ctx0, 0); - -#ifdef WHISPER_USE_FLASH_ATTN - struct ggml_tensor * Q = - ggml_permute(ctx0, - ggml_cpy(ctx0, - Qcur, 
- ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_head, n_ctx)), - 0, 2, 1, 3); - - struct ggml_tensor * K = - ggml_permute(ctx0, - ggml_cpy(ctx0, - Kcur, - ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_head, n_ctx)), - 0, 2, 1, 3); - - struct ggml_tensor * V = - ggml_cpy(ctx0, - ggml_permute(ctx0, - ggml_reshape_3d(ctx0, - Vcur, - n_state/n_head, n_head, n_ctx), - 1, 2, 0, 3), - ggml_new_tensor_3d(ctx0, wctx.wtype, n_ctx, n_state/n_head, n_head) - ); - - struct ggml_tensor * KQV = ggml_flash_attn(ctx0, Q, K, V, false); -#else - struct ggml_tensor * Q = - ggml_permute(ctx0, - ggml_cpy(ctx0, - Qcur, - ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state/n_head, n_head, n_ctx)), - 0, 2, 1, 3); - - struct ggml_tensor * K = - ggml_permute(ctx0, - ggml_cpy(ctx0, - Kcur, - ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_head, n_ctx)), - 0, 2, 1, 3); - - // K * Q - struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); - - struct ggml_tensor * KQ_scaled = - ggml_scale(ctx0, - KQ, - ggml_new_f32(ctx0, 1.0f/sqrt(float(n_state)/n_head)) - ); - - struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_scaled); - - //struct ggml_tensor * V_trans = - // ggml_permute(ctx0, - // ggml_cpy(ctx0, - // Vcur, - // ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_head, n_ctx)), - // 1, 2, 0, 3); - - //struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max); - - struct ggml_tensor * V = - ggml_cpy(ctx0, - ggml_permute(ctx0, - ggml_reshape_3d(ctx0, - Vcur, - n_state/n_head, n_head, n_ctx), - 0, 2, 1, 3), - ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_ctx, n_head) - ); - - struct ggml_tensor * KQV = ggml_mul_mat(ctx0, ggml_transpose(ctx0, V), KQ_soft_max); -#endif - struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); - - wctx.use_buf(ctx0, 1); - - cur = ggml_cpy(ctx0, - KQV_merged, - ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx)); - } - - // projection - { - wctx.use_buf(ctx0, 0); - - cur = ggml_mul_mat(ctx0, - layer.attn_ln_1_w, - cur); - - wctx.use_buf(ctx0, 1); - - cur = ggml_add(ctx0, - ggml_repeat(ctx0, layer.attn_ln_1_b, cur), - cur); - } - - wctx.use_buf(ctx0, 2); - - // add the input - cur = ggml_add(ctx0, cur, inpL); - - struct ggml_tensor * inpFF = cur; - - // feed-forward network - { - // norm - { - wctx.use_buf(ctx0, 0); - - cur = ggml_norm(ctx0, inpFF); - - wctx.use_buf(ctx0, 1); - - // cur = mlp_ln_w*cur + mlp_ln_b - cur = ggml_add(ctx0, - ggml_mul(ctx0, - ggml_repeat(ctx0, layer.mlp_ln_w, cur), - cur), - ggml_repeat(ctx0, layer.mlp_ln_b, cur)); - } - -#ifdef WHISPER_USE_FLASH_FF - wctx.use_buf(ctx0, 0); - - cur = ggml_flash_ff(ctx0, - ggml_cpy(ctx0, cur, ggml_new_tensor_2d(ctx0, wctx.wtype, n_state, n_ctx)), - layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b); -#else - wctx.use_buf(ctx0, 0); - - // fully connected - cur = ggml_mul_mat(ctx0, - layer.mlp_0_w, - cur); - - wctx.use_buf(ctx0, 1); - - cur = ggml_add(ctx0, - ggml_repeat(ctx0, layer.mlp_0_b, cur), - cur); - - wctx.use_buf(ctx0, 0); - - // GELU activation - cur = ggml_gelu(ctx0, cur); - - wctx.use_buf(ctx0, 1); - - // projection - cur = ggml_mul_mat(ctx0, - layer.mlp_1_w, - cur); - - wctx.use_buf(ctx0, 0); - - cur = ggml_add(ctx0, - ggml_repeat(ctx0, layer.mlp_1_b, cur), - cur); -#endif - } - - wctx.use_buf(ctx0, 3); - - inpL = ggml_add(ctx0, cur, inpFF); - } - - cur = inpL; - - // norm - { - wctx.use_buf(ctx0, 0); - - cur = ggml_norm(ctx0, cur); - - wctx.use_buf(ctx0, 1); - - // cur = ln_f_g*cur + ln_f_b - cur = ggml_add(ctx0, - ggml_mul(ctx0, - 
ggml_repeat(ctx0, model.e_ln_w, cur), - cur), - ggml_repeat(ctx0, model.e_ln_b, cur)); - } - - wctx.use_buf(ctx0, -1); - - // run the computation - { - struct ggml_cgraph gf = {}; - gf.n_threads = n_threads; - - ggml_build_forward_expand(&gf, cur); - ggml_graph_compute (ctx0, &gf); - - //ggml_graph_print(&gf); - } - - // cur - //{ - // printf("ne0 = %d\n", cur->ne[0]); - // printf("ne1 = %d\n", cur->ne[1]); - // for (int i = 0; i < 10; ++i) { - // printf("%8.4f ", ((float *)(cur->data))[i]); - // } - // printf("... "); - // for (int i = cur->ne[0] - 10; i < cur->ne[0]; ++i) { - // printf("%8.4f ", ((float *)(cur->data))[i]); - // } - // printf("\n"); - //} - - // pre-compute cross-attention memory - { - struct ggml_cgraph gf = {}; - gf.n_threads = n_threads; - - // TODO: hack to disconnect the encoded features from the previous graph - cur->op = GGML_OP_NONE; - cur->src0 = nullptr; - cur->src1 = nullptr; - - for (int il = 0; il < model.hparams.n_text_layer; ++il) { - auto & layer = model.layers_decoder[il]; - - wctx.use_buf(ctx0, 0); - - struct ggml_tensor * Kcross = ggml_mul_mat(ctx0, - layer.cross_attn_k_w, - cur); - - Kcross = ggml_scale(ctx0, Kcross, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25))); - - wctx.use_buf(ctx0, 1); - - struct ggml_tensor * Vcross = ggml_mul_mat(ctx0, - layer.cross_attn_v_w, - cur); - - Vcross = ggml_add(ctx0, - ggml_repeat(ctx0, - layer.cross_attn_v_b, - Vcross), - Vcross); - - wctx.use_buf(ctx0, -1); - - //struct ggml_tensor * k = ggml_view_1d(ctx0, wctx.kv_cross.k, n_state*n_ctx, (ggml_element_size(wctx.kv_cross.k)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx)); - //struct ggml_tensor * v = ggml_view_1d(ctx0, wctx.kv_cross.v, n_state*n_ctx, (ggml_element_size(wctx.kv_cross.v)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx)); - struct ggml_tensor * k = ggml_view_1d(ctx0, wctx.kv_cross.k, n_state*n_ctx, (ggml_element_size(wctx.kv_cross.k)*n_state)*(il*n_ctx)); - struct ggml_tensor * v = ggml_view_1d(ctx0, wctx.kv_cross.v, n_state*n_ctx, (ggml_element_size(wctx.kv_cross.v)*n_state)*(il*n_ctx)); - - ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcross, k)); - ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcross, v)); - } - - ggml_graph_compute(ctx0, &gf); - //ggml_graph_print(&gf); - } - - //////////////////////////////////////////////////////////////////////////// - - //printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__, - // ggml_used_mem(ctx0)/1024.0/1024.0, - // wctx.get_buf_max_mem(0)/1024.0/1024.0, - // wctx.get_buf_max_mem(1)/1024.0/1024.0, - // wctx.get_buf_max_mem(2)/1024.0/1024.0, - // wctx.get_buf_max_mem(3)/1024.0/1024.0); - - ggml_free(ctx0); - - wctx.t_encode_us += ggml_time_us() - t_start_us; - wctx.n_encode++; - - return true; -} - -// evaluate the decoder -// -// given text prompt + audio features -> computes the logits for the next token -// -// - model: the model -// - n_threads: number of threads to use -// - tokens: text prompt -// - n_tokens: number of tokens in the prompt -// - n_past: number of past tokens to prefix the prompt with -// -static bool whisper_decode( - whisper_context & wctx, - whisper_decoder & decoder, - const whisper_token * tokens, - const int n_tokens, - const int n_past, - const int n_threads) { - const int64_t t_start_us = ggml_time_us(); - - const auto & model = wctx.model; - const auto & hparams = model.hparams; - - auto & kv_self = decoder.kv_self; - - WHISPER_ASSERT(!!kv_self.ctx); - - auto & logits_out = wctx.logits; - - const int n_vocab = hparams.n_vocab; - - const int n_ctx = 
hparams.n_text_ctx; - const int n_state = hparams.n_text_state; - const int n_head = hparams.n_text_head; - const int n_layer = hparams.n_text_layer; - - const int N = n_tokens; - const int M = wctx.exp_n_audio_ctx > 0 ? wctx.exp_n_audio_ctx : hparams.n_audio_ctx; - - //WHISPER_PRINT_DEBUG("%s: n_past = %d, N = %d, M = %d, n_ctx = %d\n", __func__, n_past, N, M, n_ctx); - - struct ggml_init_params params; - params.mem_size = wctx.buf_compute.size(); - params.mem_buffer = wctx.buf_compute.data(); - - struct ggml_context * ctx0 = ggml_init(params); - - struct ggml_cgraph gf = {}; - gf.n_threads = n_threads; - - struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); - memcpy(embd->data, tokens, N*ggml_element_size(embd)); - - struct ggml_tensor * position = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); - for (int i = 0; i < N; ++i) { - ((int32_t *) position->data)[i] = n_past + i; - } - - wctx.use_buf(ctx0, 3); - - // token encoding + position encoding - struct ggml_tensor * cur = - ggml_add(ctx0, - ggml_get_rows(ctx0, model.d_te, embd), - ggml_get_rows(ctx0, model.d_pe, position)); - - struct ggml_tensor * inpL = cur; - - for (int il = 0; il < n_layer; ++il) { - const auto & layer = model.layers_decoder[il]; - - // norm - { - wctx.use_buf(ctx0, 0); - - cur = ggml_norm(ctx0, inpL); - - // cur = ln_0_w*cur + ln_0_b - cur = ggml_add(ctx0, - ggml_mul(ctx0, - ggml_repeat(ctx0, layer.attn_ln_0_w, cur), - cur), - ggml_repeat(ctx0, layer.attn_ln_0_b, cur)); - } - - // self-attention - { - wctx.use_buf(ctx0, 1); - - struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, - layer.attn_q_w, - cur); - - Qcur = ggml_add(ctx0, - ggml_repeat(ctx0, - layer.attn_q_b, - Qcur), - Qcur); - - Qcur = ggml_scale(ctx0, Qcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25))); - - // note: no bias for Key - struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, - layer.attn_k_w, - cur); - - Kcur = ggml_scale(ctx0, Kcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25))); - - struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, - layer.attn_v_w, - cur); - - Vcur = ggml_add(ctx0, - ggml_repeat(ctx0, - layer.attn_v_b, - Vcur), - Vcur); - - // store key and value to memory - { - struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_state, (ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + n_past)); - struct ggml_tensor * v = ggml_view_1d(ctx0, kv_self.v, N*n_state, (ggml_element_size(kv_self.v)*n_state)*(il*n_ctx + n_past)); - - ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k)); - ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v)); - } - - // ------ - - wctx.use_buf(ctx0, 0); - - struct ggml_tensor * Q = - ggml_permute(ctx0, - ggml_cpy(ctx0, - Qcur, - ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state/n_head, n_head, N)), - 0, 2, 1, 3); - - struct ggml_tensor * K = - ggml_permute(ctx0, - ggml_reshape_3d(ctx0, - ggml_view_1d(ctx0, kv_self.k, (n_past + N)*n_state, il*n_ctx*ggml_element_size(kv_self.k)*n_state), - n_state/n_head, n_head, n_past + N), - 0, 2, 1, 3); - - wctx.use_buf(ctx0, 1); - - // K * Q - struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); - - wctx.use_buf(ctx0, 0); - - //struct ggml_tensor * KQ_scaled = - // ggml_scale(ctx0, - // KQ, - // ggml_new_f32(ctx0, 1.0f/sqrt(float(n_state)/n_head)) - // ); - - struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ, n_past); - - wctx.use_buf(ctx0, 1); - - struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked); - - wctx.use_buf(ctx0, 0); - - struct ggml_tensor * V_trans = - ggml_permute(ctx0, - ggml_reshape_3d(ctx0, - 
ggml_view_1d(ctx0, kv_self.v, (n_past + N)*n_state, il*n_ctx*ggml_element_size(kv_self.v)*n_state), - n_state/n_head, n_head, n_past + N), - 1, 2, 0, 3); - - wctx.use_buf(ctx0, 1); - - struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max); - - struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); - - cur = ggml_cpy(ctx0, - KQV_merged, - ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, N)); - } - - // projection - { - wctx.use_buf(ctx0, 0); - - cur = ggml_mul_mat(ctx0, - layer.attn_ln_1_w, - cur); - - wctx.use_buf(ctx0, 1); - - cur = ggml_add(ctx0, - ggml_repeat(ctx0, layer.attn_ln_1_b, cur), - cur); - } - - wctx.use_buf(ctx0, 2); - - // add the input - struct ggml_tensor * inpCA = ggml_add(ctx0, cur, inpL); - - // norm - { - wctx.use_buf(ctx0, 0); - - cur = ggml_norm(ctx0, inpCA); // note: we use inpCA here - - wctx.use_buf(ctx0, 1); - - // cur = ln_0_w*cur + ln_0_b - cur = ggml_add(ctx0, - ggml_mul(ctx0, - ggml_repeat(ctx0, layer.cross_attn_ln_0_w, cur), - cur), - ggml_repeat(ctx0, layer.cross_attn_ln_0_b, cur)); - } - - // cross-attention - { - wctx.use_buf(ctx0, 0); - - struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, - layer.cross_attn_q_w, - cur); - - Qcur = ggml_add(ctx0, - ggml_repeat(ctx0, - layer.cross_attn_q_b, - Qcur), - Qcur); - - Qcur = ggml_scale(ctx0, Qcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25))); - - // Kcross is already scaled - struct ggml_tensor * Kcross = - ggml_reshape_3d(ctx0, - ggml_view_1d(ctx0, wctx.kv_cross.k, M*n_state, il*M*ggml_element_size(wctx.kv_cross.k)*n_state), - n_state/n_head, n_head, M); - - struct ggml_tensor * Vcross = - ggml_reshape_3d(ctx0, - ggml_view_1d(ctx0, wctx.kv_cross.v, M*n_state, il*M*ggml_element_size(wctx.kv_cross.v)*n_state), - n_state/n_head, n_head, M); - - struct ggml_tensor * V_trans = ggml_permute(ctx0, Vcross, 1, 2, 0, 3); - - // ------ - - wctx.use_buf(ctx0, 1); - - struct ggml_tensor * Q = - ggml_permute(ctx0, - ggml_cpy(ctx0, - Qcur, - ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state/n_head, n_head, N)), - 0, 2, 1, 3); - - struct ggml_tensor * K = ggml_permute(ctx0, Kcross, 0, 2, 1, 3); - - wctx.use_buf(ctx0, 0); - - // K * Q - struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); - - //struct ggml_tensor * KQ_scaled = - // ggml_scale(ctx0, - // KQ, - // ggml_new_f32(ctx0, 1.0f/sqrt(float(n_state)/n_head)) - // ); - - // no masking for cross-attention - //struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past); - - wctx.use_buf(ctx0, 1); - - struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ); - - wctx.use_buf(ctx0, 0); - - struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max); - - wctx.use_buf(ctx0, 1); - - struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); - - // cur = KQV_merged.contiguous().view(n_state, N) - cur = ggml_cpy(ctx0, - KQV_merged, - ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, N)); - } - - // projection - { - wctx.use_buf(ctx0, 0); - - cur = ggml_mul_mat(ctx0, - layer.cross_attn_ln_1_w, - cur); - - wctx.use_buf(ctx0, 1); - - cur = ggml_add(ctx0, - ggml_repeat(ctx0, layer.cross_attn_ln_1_b, cur), - cur); - } - - wctx.use_buf(ctx0, 2); - - // add the input - cur = ggml_add(ctx0, cur, inpCA); - - struct ggml_tensor * inpFF = cur; - - // feed-forward network - { - // norm - { - wctx.use_buf(ctx0, 0); - - cur = ggml_norm(ctx0, inpFF); - - wctx.use_buf(ctx0, 1); - - // cur = mlp_ln_w*cur + mlp_ln_b - cur = ggml_add(ctx0, - ggml_mul(ctx0, - ggml_repeat(ctx0, layer.mlp_ln_w, cur), - cur), - ggml_repeat(ctx0, 
layer.mlp_ln_b, cur)); - } - - wctx.use_buf(ctx0, 0); - - // fully connected - cur = ggml_mul_mat(ctx0, - layer.mlp_0_w, - cur); - - wctx.use_buf(ctx0, 1); - - cur = ggml_add(ctx0, - ggml_repeat(ctx0, layer.mlp_0_b, cur), - cur); - - wctx.use_buf(ctx0, 0); - - // GELU activation - cur = ggml_gelu(ctx0, cur); - - wctx.use_buf(ctx0, 1); - - // projection - cur = ggml_mul_mat(ctx0, - layer.mlp_1_w, - cur); - - wctx.use_buf(ctx0, 0); - - cur = ggml_add(ctx0, - ggml_repeat(ctx0, layer.mlp_1_b, cur), - cur); - } - - wctx.use_buf(ctx0, 3); - - inpL = ggml_add(ctx0, cur, inpFF); - } - - cur = inpL; - - // norm - { - wctx.use_buf(ctx0, 0); - - cur = ggml_norm(ctx0, cur); - - wctx.use_buf(ctx0, 1); - - cur = ggml_add(ctx0, - ggml_mul(ctx0, - ggml_repeat(ctx0, model.d_ln_w, cur), - cur), - ggml_repeat(ctx0, model.d_ln_b, cur)); - } - - wctx.use_buf(ctx0, 0); - - // compute logits only for the last token - // comment this line to compute logits for all N tokens - // might be useful in the future - cur = ggml_view_2d(ctx0, cur, cur->ne[0], 1, cur->nb[1], (cur->ne[1] - 1)*cur->nb[1]); - - struct ggml_tensor * logits = ggml_mul_mat(ctx0, model.d_te, cur); - - wctx.use_buf(ctx0, -1); - - // run the computation - { - ggml_build_forward_expand(&gf, logits); - ggml_graph_compute (ctx0, &gf); - } - - // extract logits for all N tokens - //logits_out.resize(N*n_vocab); - //memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*N*n_vocab); - - // extract logits only for the last token - logits_out.resize(n_vocab); - memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*n_vocab); - - if (N > 1) { - //printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__, - // ggml_used_mem(ctx0)/1024.0/1024.0, - // wctx.get_buf_max_mem(0)/1024.0/1024.0, - // wctx.get_buf_max_mem(1)/1024.0/1024.0, - // wctx.get_buf_max_mem(2)/1024.0/1024.0, - // wctx.get_buf_max_mem(3)/1024.0/1024.0); - } - - ggml_free(ctx0); - - wctx.t_decode_us += ggml_time_us() - t_start_us; - wctx.n_decode++; - - return true; -} - -// 500 -> 00:05.000 -// 6000 -> 01:00.000 -static std::string to_timestamp(int64_t t, bool comma = false) { - int64_t msec = t * 10; - int64_t hr = msec / (1000 * 60 * 60); - msec = msec - hr * (1000 * 60 * 60); - int64_t min = msec / (1000 * 60); - msec = msec - min * (1000 * 60); - int64_t sec = msec / 1000; - msec = msec - sec * 1000; - - char buf[32]; - snprintf(buf, sizeof(buf), "%02d:%02d:%02d%s%03d", (int) hr, (int) min, (int) sec, comma ? 
"," : ".", (int) msec); - - return std::string(buf); -} - -// naive Discrete Fourier Transform -// input is real-valued -// output is complex-valued -static void dft(const std::vector & in, std::vector & out) { - int N = in.size(); - - out.resize(N*2); - - for (int k = 0; k < N; k++) { - float re = 0; - float im = 0; - - for (int n = 0; n < N; n++) { - float angle = 2*M_PI*k*n/N; - re += in[n]*cos(angle); - im -= in[n]*sin(angle); - } - - out[k*2 + 0] = re; - out[k*2 + 1] = im; - } -} - -// Cooley-Tukey FFT -// poor man's implementation - use something better -// input is real-valued -// output is complex-valued -static void fft(const std::vector & in, std::vector & out) { - out.resize(in.size()*2); - - int N = in.size(); - - if (N == 1) { - out[0] = in[0]; - out[1] = 0; - return; - } - - if (N%2 == 1) { - dft(in, out); - return; - } - - std::vector even; - std::vector odd; - - even.reserve(N/2); - odd.reserve(N/2); - - for (int i = 0; i < N; i++) { - if (i % 2 == 0) { - even.push_back(in[i]); - } else { - odd.push_back(in[i]); - } - } - - std::vector even_fft; - std::vector odd_fft; - - fft(even, even_fft); - fft(odd, odd_fft); - - for (int k = 0; k < N/2; k++) { - float theta = 2*M_PI*k/N; - - float re = cos(theta); - float im = -sin(theta); - - float re_odd = odd_fft[2*k + 0]; - float im_odd = odd_fft[2*k + 1]; - - out[2*k + 0] = even_fft[2*k + 0] + re*re_odd - im*im_odd; - out[2*k + 1] = even_fft[2*k + 1] + re*im_odd + im*re_odd; - - out[2*(k + N/2) + 0] = even_fft[2*k + 0] - re*re_odd + im*im_odd; - out[2*(k + N/2) + 1] = even_fft[2*k + 1] - re*im_odd - im*re_odd; - } -} - -// ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L92-L124 -static bool log_mel_spectrogram( - whisper_context & wctx, - const float * samples, - const int n_samples, - const int /*sample_rate*/, - const int fft_size, - const int fft_step, - const int n_mel, - const int n_threads, - const whisper_filters & filters, - const bool speed_up, - whisper_mel & mel) { - const int64_t t_start_us = ggml_time_us(); - - // Hanning window - std::vector hann; - hann.resize(fft_size); - for (int i = 0; i < fft_size; i++) { - hann[i] = 0.5*(1.0 - cos((2.0*M_PI*i)/(fft_size))); - } - - mel.n_mel = n_mel; - mel.n_len = (n_samples)/fft_step; - mel.data.resize(mel.n_mel*mel.n_len); - - const int n_fft = 1 + (speed_up ? 
fft_size/4 : fft_size/2); - - //printf("%s: n_samples = %d, n_len = %d\n", __func__, n_samples, mel.n_len); - //printf("%s: recording length: %f s\n", __func__, (float) n_samples/sample_rate); - - std::vector workers(n_threads); - for (int iw = 0; iw < n_threads; ++iw) { - workers[iw] = std::thread([&](int ith) { - std::vector fft_in; - fft_in.resize(fft_size); - for (int i = 0; i < fft_size; i++) { - fft_in[i] = 0.0; - } - - std::vector fft_out; - fft_out.resize(2*fft_size); - - for (int i = ith; i < mel.n_len; i += n_threads) { - const int offset = i*fft_step; - - // apply Hanning window - for (int j = 0; j < fft_size; j++) { - if (offset + j < n_samples) { - fft_in[j] = hann[j]*samples[offset + j]; - } else { - fft_in[j] = 0.0; - } - } - - // FFT -> mag^2 - fft(fft_in, fft_out); - - for (int j = 0; j < fft_size; j++) { - fft_out[j] = (fft_out[2*j + 0]*fft_out[2*j + 0] + fft_out[2*j + 1]*fft_out[2*j + 1]); - } - for (int j = 1; j < fft_size/2; j++) { - //if (i == 0) { - // printf("%d: %f %f\n", j, fft_out[j], fft_out[fft_size - j]); - //} - fft_out[j] += fft_out[fft_size - j]; - } - if (i == 0) { - //for (int j = 0; j < fft_size; j++) { - // printf("%d: %e\n", j, fft_out[j]); - //} - } - - if (speed_up) { - // scale down in the frequency domain results in a speed up in the time domain - for (int j = 0; j < n_fft; j++) { - fft_out[j] = 0.5*(fft_out[2*j] + fft_out[2*j + 1]); - } - } - - // mel spectrogram - for (int j = 0; j < mel.n_mel; j++) { - double sum = 0.0; - - for (int k = 0; k < n_fft; k++) { - sum += fft_out[k]*filters.data[j*n_fft + k]; - } - if (sum < 1e-10) { - sum = 1e-10; - } - - sum = log10(sum); - - mel.data[j*mel.n_len + i] = sum; - } - } - }, iw); - } - - for (int iw = 0; iw < n_threads; ++iw) { - workers[iw].join(); - } - - // clamping and normalization - double mmax = -1e20; - for (int i = 0; i < mel.n_mel*mel.n_len; i++) { - if (mel.data[i] > mmax) { - mmax = mel.data[i]; - } - } - //printf("%s: max = %f\n", __func__, mmax); - - mmax -= 8.0; - - for (int i = 0; i < mel.n_mel*mel.n_len; i++) { - if (mel.data[i] < mmax) { - mel.data[i] = mmax; - } - - mel.data[i] = (mel.data[i] + 4.0)/4.0; - } - - wctx.t_mel_us += ggml_time_us() - t_start_us; - - return true; -} - -// split text into tokens -// -// ref: https://github.com/openai/gpt-2/blob/a74da5d99abaaba920de8131d64da2862a8f213b/src/encoder.py#L53 -// -// Regex (Python): -// r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""" -// -// Regex (C++): -// R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)" -// -static std::vector tokenize(const whisper_vocab & vocab, const std::string & text) { - std::vector words; - - // first split the text into words - { - std::string str = text; - std::string pat = R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)"; - - std::regex re(pat); - std::smatch m; - - while (std::regex_search(str, m, re)) { - for (auto x : m) { - words.push_back(x); - } - str = m.suffix(); - } - } - - // find the longest tokens that form the words: - std::vector tokens; - for (const auto & word : words) { - if (word.empty()) continue; - - int i = 0; - int n = word.size(); - while (i < n) { - int j = n; - while (j > i) { - auto it = vocab.token_to_id.find(word.substr(i, j-i)); - if (it != vocab.token_to_id.end()) { - tokens.push_back(it->second); - i = j; - break; - } - --j; - } - if (i == n) { - break; - } - if (j == i) { - auto sub = word.substr(i, 1); - if 
(vocab.token_to_id.find(sub) != vocab.token_to_id.end()) { - tokens.push_back(vocab.token_to_id.at(sub)); - } else { - fprintf(stderr, "%s: unknown token '%s'\n", __func__, sub.data()); - } - ++i; - } - } - } - - return tokens; -} - -// -// interface implementation -// - -struct whisper_context * whisper_init_from_file(const char * path_model) { - whisper_model_loader loader = {}; - - fprintf(stderr, "%s: loading model from '%s'\n", __func__, path_model); - - auto fin = std::ifstream(path_model, std::ios::binary); - if (!fin) { - fprintf(stderr, "%s: failed to open '%s'\n", __func__, path_model); - return nullptr; - } - - loader.context = &fin; - loader.read = [](void * ctx, void * output, size_t read_size) { - std::ifstream * fin = (std::ifstream*)ctx; - fin->read((char *)output, read_size); - return read_size; - }; - - loader.eof = [](void * ctx) { - std::ifstream * fin = (std::ifstream*)ctx; - return fin->eof(); - }; - - loader.close = [](void * ctx) { - std::ifstream * fin = (std::ifstream*)ctx; - fin->close(); - }; - - return whisper_init(&loader); -} - -struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_size) { - struct buf_context { - uint8_t* buffer; - size_t size; - size_t current_offset; - }; - - buf_context ctx = { reinterpret_cast(buffer), buffer_size, 0 }; - whisper_model_loader loader = {}; - - fprintf(stderr, "%s: loading model from buffer\n", __func__); - - loader.context = &ctx; - - loader.read = [](void * ctx, void * output, size_t read_size) { - buf_context * buf = reinterpret_cast(ctx); - - size_t size_to_copy = buf->current_offset + read_size < buf->size ? read_size : buf->size - buf->current_offset; - - memcpy(output, buf->buffer + buf->current_offset, size_to_copy); - buf->current_offset += size_to_copy; - - return size_to_copy; - }; - - loader.eof = [](void * ctx) { - buf_context * buf = reinterpret_cast(ctx); - - return buf->current_offset >= buf->size; - }; - - loader.close = [](void * /*ctx*/) { }; - - return whisper_init(&loader); -} - -struct whisper_context * whisper_init(struct whisper_model_loader * loader) { - ggml_time_init(); - - whisper_context * ctx = new whisper_context; - - if (!whisper_model_load(loader, *ctx)) { - loader->close(loader->context); - fprintf(stderr, "%s: failed to load model\n", __func__); - delete ctx; - return nullptr; - } - - loader->close(loader->context); - - return ctx; -} - -void whisper_free(struct whisper_context * ctx) { - if (ctx) { - if (ctx->model.ctx) { - ggml_free(ctx->model.ctx); - } - if (ctx->model.buf) { - delete ctx->model.buf; - } - if (ctx->kv_cross.ctx) { - ggml_free(ctx->kv_cross.ctx); - } - for (int i = 0; i < WHISPER_MAX_DECODERS; ++i) { - if (ctx->decoders[i].kv_self.ctx) { - ggml_free(ctx->decoders[i].kv_self.ctx); - } - } - delete ctx; - } -} - -int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) { - if (!log_mel_spectrogram(*ctx, samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, false, ctx->mel)) { - fprintf(stderr, "%s: failed to compute mel spectrogram\n", __func__); - return -1; - } - - return 0; -} - -// same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2 -int whisper_pcm_to_mel_phase_vocoder(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) { - if (!log_mel_spectrogram(*ctx, samples, n_samples, WHISPER_SAMPLE_RATE, 2*WHISPER_N_FFT, 2*WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, 
ctx->model.filters, true, ctx->mel)) { - fprintf(stderr, "%s: failed to compute mel spectrogram\n", __func__); - return -1; - } - - return 0; -} - -int whisper_set_mel( - struct whisper_context * ctx, - const float * data, - int n_len, - int n_mel) { - if (n_mel != WHISPER_N_MEL) { - fprintf(stderr, "%s: invalid number of mel bands: %d (expected %d)\n", __func__, n_mel, WHISPER_N_MEL); - return -1; - } - - ctx->mel.n_len = n_len; - ctx->mel.n_mel = n_mel; - - ctx->mel.data.resize(n_len*n_mel); - memcpy(ctx->mel.data.data(), data, n_len*n_mel*sizeof(float)); - - return 0; -} - -int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) { - if (!whisper_encode(*ctx, offset, n_threads)) { - fprintf(stderr, "%s: failed to eval\n", __func__); - return -1; - } - - return 0; -} - -int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) { - // TODO: add selected_decoder_id to context - const int selected_decoder_id = 0; - - if (!whisper_decode(*ctx, ctx->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads)) { - fprintf(stderr, "%s: failed to eval\n", __func__); - return 1; - } - - return 0; -} - -int whisper_tokenize(struct whisper_context * ctx, const char * text, whisper_token * tokens, int n_max_tokens) { - const auto res = tokenize(ctx->vocab, text); - - if (n_max_tokens < (int) res.size()) { - fprintf(stderr, "%s: too many resulting tokens: %d (max %d)\n", __func__, (int) res.size(), n_max_tokens); - return -1; - } - - for (int i = 0; i < (int) res.size(); i++) { - tokens[i] = res[i]; - } - - return res.size(); -} - -int whisper_lang_max_id() { - auto max_id = 0; - for (const auto & kv : g_lang) { - max_id = std::max(max_id, kv.second.first); - } - - return max_id; -} - -int whisper_lang_id(const char * lang) { - if (!g_lang.count(lang)) { - for (const auto & kv : g_lang) { - if (kv.second.second == lang) { - return kv.second.first; - } - } - - fprintf(stderr, "%s: unknown language '%s'\n", __func__, lang); - return -1; - } - - return g_lang.at(lang).first; -} - -const char * whisper_lang_str(int id) { - for (const auto & kv : g_lang) { - if (kv.second.first == id) { - return kv.first.c_str(); - } - } - - fprintf(stderr, "%s: unknown language id %d\n", __func__, id); - return nullptr; -} - -int whisper_lang_auto_detect( - struct whisper_context * ctx, - int offset_ms, - int n_threads, - float * lang_probs) { - const int seek = offset_ms/10; - - if (seek < 0) { - fprintf(stderr, "%s: offset %dms is before the start of the audio\n", __func__, offset_ms); - return -1; - } - - if (seek >= ctx->mel.n_len) { - fprintf(stderr, "%s: offset %dms is past the end of the audio (%dms)\n", __func__, offset_ms, ctx->mel.n_len*10); - return -2; - } - - // run the encoder - if (whisper_encode(ctx, seek, n_threads) != 0) { - fprintf(stderr, "%s: failed to encode\n", __func__); - return -6; - } - - const std::vector prompt = { whisper_token_sot(ctx) }; - - if (whisper_decode(ctx, prompt.data(), prompt.size(), 0, n_threads) != 0) { - fprintf(stderr, "%s: failed to decode\n", __func__); - return -7; - } - - auto & logits_id = ctx->logits_id; - logits_id.clear(); - - for (const auto & kv : g_lang) { - const auto token_lang = whisper_token_lang(ctx, kv.second.first); - logits_id.emplace_back(ctx->logits[token_lang], kv.second.first); - } - - // sort descending - { - using pair_type = std::remove_reference::type::value_type; - std::sort(logits_id.begin(), logits_id.end(), [](const pair_type & a, const pair_type & b) { - 
return a.first > b.first; - }); - } - - // softmax - { - const auto max = logits_id[0].first; - - double sum = 0.0f; - for (auto & kv : logits_id) { - kv.first = exp(kv.first - max); - sum += kv.first; - } - - for (auto & kv : logits_id) { - kv.first /= sum; - } - } - - { - for (const auto & prob : logits_id) { - if (lang_probs) { - lang_probs[prob.second] = prob.first; - } - - //printf("%s: lang %2d (%3s): %f\n", __func__, prob.second, whisper_lang_str(prob.second), prob.first); - } - } - - return logits_id[0].second; -} - -int whisper_n_len(struct whisper_context * ctx) { - return ctx->mel.n_len; -} - -int whisper_n_vocab(struct whisper_context * ctx) { - return ctx->vocab.n_vocab; -} - -int whisper_n_text_ctx(struct whisper_context * ctx) { - return ctx->model.hparams.n_text_ctx; -} - -int whisper_n_audio_ctx(struct whisper_context * ctx) { - return ctx->model.hparams.n_audio_ctx; -} - -int whisper_is_multilingual(struct whisper_context * ctx) { - return ctx->vocab.is_multilingual() ? 1 : 0; -} - -float * whisper_get_logits(struct whisper_context * ctx) { - return ctx->logits.data(); -} - -const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token) { - return ctx->vocab.id_to_token.at(token).c_str(); -} - -whisper_token whisper_token_eot(struct whisper_context * ctx) { - return ctx->vocab.token_eot; -} - -whisper_token whisper_token_sot(struct whisper_context * ctx) { - return ctx->vocab.token_sot; -} - -whisper_token whisper_token_prev(struct whisper_context * ctx) { - return ctx->vocab.token_prev; -} - -whisper_token whisper_token_solm(struct whisper_context * ctx) { - return ctx->vocab.token_solm; -} - -whisper_token whisper_token_not(struct whisper_context * ctx) { - return ctx->vocab.token_not; -} - -whisper_token whisper_token_beg(struct whisper_context * ctx) { - return ctx->vocab.token_beg; -} - -whisper_token whisper_token_lang(struct whisper_context * ctx, int lang_id) { - return whisper_token_sot(ctx) + 1 + lang_id; -} - -whisper_token whisper_token_translate(void) { - return whisper_vocab::token_translate; -} - -whisper_token whisper_token_transcribe(void) { - return whisper_vocab::token_transcribe; -} - -void whisper_print_timings(struct whisper_context * ctx) { - const int64_t t_end_us = ggml_time_us(); - - const int32_t n_sample = std::max(1, ctx->n_sample); - const int32_t n_encode = std::max(1, ctx->n_encode); - const int32_t n_decode = std::max(1, ctx->n_decode); - - fprintf(stderr, "\n"); - fprintf(stderr, "%s: fallbacks = %3d p / %3d h\n", __func__, ctx->n_fail_p, ctx->n_fail_h); - fprintf(stderr, "%s: load time = %8.2f ms\n", __func__, ctx->t_load_us/1000.0f); - fprintf(stderr, "%s: mel time = %8.2f ms\n", __func__, ctx->t_mel_us/1000.0f); - fprintf(stderr, "%s: sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f*ctx->t_sample_us, n_sample, 1e-3f*ctx->t_sample_us/n_sample); - fprintf(stderr, "%s: encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f*ctx->t_encode_us, n_encode, 1e-3f*ctx->t_encode_us/n_encode); - fprintf(stderr, "%s: decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f*ctx->t_decode_us, n_decode, 1e-3f*ctx->t_decode_us/n_decode); - fprintf(stderr, "%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f); -} - -void whisper_reset_timings(struct whisper_context * ctx) { - ctx->t_sample_us = 0; - ctx->t_encode_us = 0; - ctx->t_decode_us = 0; -} - -const char * whisper_print_system_info(void) { - static std::string s; - - s = ""; - s += "AVX = " + 
std::to_string(ggml_cpu_has_avx()) + " | "; - s += "AVX2 = " + std::to_string(ggml_cpu_has_avx2()) + " | "; - s += "AVX512 = " + std::to_string(ggml_cpu_has_avx512()) + " | "; - s += "FMA = " + std::to_string(ggml_cpu_has_fma()) + " | "; - s += "NEON = " + std::to_string(ggml_cpu_has_neon()) + " | "; - s += "ARM_FMA = " + std::to_string(ggml_cpu_has_arm_fma()) + " | "; - s += "F16C = " + std::to_string(ggml_cpu_has_f16c()) + " | "; - s += "FP16_VA = " + std::to_string(ggml_cpu_has_fp16_va()) + " | "; - s += "WASM_SIMD = " + std::to_string(ggml_cpu_has_wasm_simd()) + " | "; - s += "BLAS = " + std::to_string(ggml_cpu_has_blas()) + " | "; - s += "SSE3 = " + std::to_string(ggml_cpu_has_sse3()) + " | "; - s += "VSX = " + std::to_string(ggml_cpu_has_vsx()) + " | "; - - return s.c_str(); -} - -//////////////////////////////////////////////////////////////////////////// - -struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy) { - struct whisper_full_params result = { - /*.strategy =*/ strategy, - - /*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()), - /*.n_max_text_ctx =*/ 16384, - /*.offset_ms =*/ 0, - /*.duration_ms =*/ 0, - - /*.translate =*/ false, - /*.no_context =*/ false, - /*.single_segment =*/ false, - /*.print_special =*/ false, - /*.print_progress =*/ true, - /*.print_realtime =*/ false, - /*.print_timestamps =*/ true, - - /*.token_timestamps =*/ false, - /*.thold_pt =*/ 0.01f, - /*.thold_ptsum =*/ 0.01f, - /*.max_len =*/ 0, - /*.split_on_word =*/ false, - /*.max_tokens =*/ 0, - - /*.speed_up =*/ false, - /*.audio_ctx =*/ 0, - - /*.prompt_tokens =*/ nullptr, - /*.prompt_n_tokens =*/ 0, - - /*.language =*/ "en", - - /*.suppress_blank =*/ true, - /*.suppress_non_speech_tokens =*/true, - - /*.temperature =*/ 0.0f, - /*.max_initial_ts =*/ 1.0f, - /*.length_penalty =*/ -1.0f, - - /*.temperature_inc =*/ 0.2f, - /*.entropy_thold =*/ 2.4f, - /*.logprob_thold =*/ -1.0f, - /*.no_speech_thold =*/ 0.6f, - - /*.greedy =*/ { - /*.best_of =*/ -1, - }, - - /*.beam_search =*/ { - /*.beam_size =*/ -1, - - /*.patience =*/ -1.0f, - }, - - /*.new_segment_callback =*/ nullptr, - /*.new_segment_callback_user_data =*/ nullptr, - - /*.encoder_begin_callback =*/ nullptr, - /*.encoder_begin_callback_user_data =*/ nullptr, - }; - - switch (strategy) { - case WHISPER_SAMPLING_GREEDY: - { - result.greedy = { - /*.best_of =*/ 1, - }; - } break; - case WHISPER_SAMPLING_BEAM_SEARCH: - { - result.beam_search = { - /*.beam_size =*/ 5, - - /*.patience =*/ -1.0f, - }; - } break; - } - - return result; -} - -// forward declarations -static std::vector get_signal_energy(const float * signal, int n_samples, int n_samples_per_half_window); -static void whisper_exp_compute_token_level_timestamps( - struct whisper_context & ctx, - int i_segment, - float thold_pt, - float thold_ptsum); - -// trim from start (in place) -static inline void ltrim(std::string &s) { - s.erase(s.begin(), std::find_if(s.begin(), s.end(), [](unsigned char ch) { - return !std::isspace(ch); - })); -} - -// trim from end (in place) -static inline void rtrim(std::string &s) { - s.erase(std::find_if(s.rbegin(), s.rend(), [](unsigned char ch) { - return !std::isspace(ch); - }).base(), s.end()); -} - -// trim from both ends (in place) -static inline void trim(std::string &s) { - rtrim(s); - ltrim(s); -} - -static inline bool should_split_on_word(const char * txt, bool split_on_word) { - if (!split_on_word) return true; - - return txt[0] == ' '; -} - -// wrap the last segment to max_len 
characters -// returns the number of new segments -static int whisper_wrap_segment(struct whisper_context & ctx, int max_len, bool split_on_word) { - auto segment = ctx.result_all.back(); - - int res = 1; - int acc = 0; - - std::string text; - - for (int i = 0; i < (int) segment.tokens.size(); i++) { - const auto & token = segment.tokens[i]; - if (token.id >= whisper_token_eot(&ctx)) { - continue; - } - - const auto txt = whisper_token_to_str(&ctx, token.id); - const int cur = strlen(txt); - - if (acc + cur > max_len && i > 0 && should_split_on_word(txt, split_on_word)) { - // split here - if (split_on_word) { - trim(text); - } - - ctx.result_all.back().text = std::move(text); - ctx.result_all.back().t1 = token.t0; - ctx.result_all.back().tokens.resize(i); - - ctx.result_all.push_back({}); - ctx.result_all.back().t0 = token.t0; - ctx.result_all.back().t1 = segment.t1; - - // add tokens [i, end] to the new segment - ctx.result_all.back().tokens.insert( - ctx.result_all.back().tokens.end(), - segment.tokens.begin() + i, - segment.tokens.end()); - - acc = 0; - text = ""; - - segment = ctx.result_all.back(); - i = -1; - - res++; - } else { - acc += cur; - text += txt; - } - } - - if (split_on_word) { - trim(text); - } - ctx.result_all.back().text = std::move(text); - - return res; -} - -static const std::vector non_speech_tokens -{ - "\"", "#", "(", ")", "*", "+", "/", ":", ";", "<", "=", ">", "@", "[", "\\", "]", "^", - "_", "`", "{", "|", "}", "~", "「", "」", "『", "』", "<<", ">>", "<<<", ">>>", "--", - "---", "-(", "-[", "('", "(\"", "((", "))", "(((", ")))", "[[", "]]", "{{", "}}", "♪♪", - "♪♪♪","♩", "♪", "♫", "♬", "♭", "♮", "♯" -}; - -// process the logits for the selected decoder -// - applies logit filters -// - computes logprobs and probs -static void whisper_process_logits( - const struct whisper_context & ctx, - const struct whisper_full_params params, - struct whisper_decoder & decoder, - float temperature) { - const auto & vocab = ctx.vocab; - const auto & tokens_cur = decoder.sequence.tokens; - - const bool is_initial = tokens_cur.size() == 0; - const int n_logits = vocab.id_to_token.size(); - - WHISPER_ASSERT(n_logits == ctx.vocab.n_vocab); - - // extract the logits for the last token - // we will be mutating and therefore we don't want to use the ctx.logits buffer directly - auto & probs = decoder.probs; - auto & logits = decoder.logits; - auto & logprobs = decoder.logprobs; - { - logits.resize(n_logits); - memcpy(logits.data(), ctx.logits.data() + (ctx.logits.size() - n_logits), n_logits*sizeof(float)); - - if (temperature > 0.0f) { - for (int i = 0; i < n_logits; i++) { - logits[i] /= temperature; - } - } - - // will be populated a bit later - probs.resize(n_logits); - logprobs.resize(n_logits); - } - - // apply logit filters here - // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L480-L493 - { - // suppress blank - // https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L388-L390 - if (params.suppress_blank) { - if (is_initial) { - logits[vocab.token_eot] = -INFINITY; - logits[vocab.token_to_id.at(" ")] = -INFINITY; - } - } - - // suppress <|notimestamps|> token - // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L410-L412 - logits[vocab.token_not] = -INFINITY; - - // suppress sot and solm tokens - logits[vocab.token_sot] = -INFINITY; - logits[vocab.token_solm] = -INFINITY; - - // 
suppress task tokens - logits[vocab.token_translate] = -INFINITY; - logits[vocab.token_transcribe] = -INFINITY; - - - // suppress non-speech tokens - // ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253 - if (params.suppress_non_speech_tokens) - { - for (const std::string &token : non_speech_tokens) - { - std::string suppress_tokens[] = {token, " " + token}; - for (const std::string &suppress_token : suppress_tokens) - { - if (vocab.token_to_id.find(suppress_token) != vocab.token_to_id.end()) - { - logits[vocab.token_to_id.at(suppress_token)] = -INFINITY; - } - } - } - // allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word - if (vocab.token_to_id.find(" -") != vocab.token_to_id.end()) - { - logits[vocab.token_to_id.at(" -")] = -INFINITY; - } - if (vocab.token_to_id.find(" '") != vocab.token_to_id.end()) - { - logits[vocab.token_to_id.at(" '")] = -INFINITY; - } - } - - // timestamps have to appear in pairs, except directly before EOT; mask logits accordingly - // https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L414-L424 - { - const bool last_was_timestamp = tokens_cur.size() > 0 && tokens_cur.back().id >= vocab.token_beg; - const bool penultimate_was_timestamp = tokens_cur.size() < 2 || tokens_cur[tokens_cur.size() - 2].id >= vocab.token_beg; - - //fprintf(stderr, "last_was_timestamp=%d penultimate_was_timestamp=%d\n", last_was_timestamp, penultimate_was_timestamp); - - if (last_was_timestamp) { - if (penultimate_was_timestamp) { - for (int i = vocab.token_beg; i < n_logits; ++i) { - logits[i] = -INFINITY; - } - } else { - for (int i = 0; i < vocab.token_eot; ++i) { - logits[i] = -INFINITY; - } - } - } - } - - // the initial timestamp cannot be larger than max_initial_ts - // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L426-L429 - if (is_initial && params.max_initial_ts > 0.0f) { - const float precision = float(WHISPER_CHUNK_SIZE)/ctx.model.hparams.n_audio_ctx; - const int tid0 = std::round(params.max_initial_ts/precision); - - for (int i = vocab.token_beg + tid0 + 1; i < n_logits; ++i) { - logits[i] = -INFINITY; - } - } - - // condition timestamp tokens to be increasing - // ref: https://github.com/openai/whisper/pull/831#issuecomment-1385910556 - if (decoder.has_ts) { - const int tid0 = decoder.seek_delta/2; - - for (int i = vocab.token_beg; i < vocab.token_beg + tid0; ++i) { - logits[i] = -INFINITY; - } - } - - // populate the logprobs array (log_softmax) - { - const float logit_max = *std::max_element(logits.begin(), logits.end()); - float logsumexp = 0.0f; - for (int i = 0; i < n_logits; ++i) { - if (logits[i] > -INFINITY) { - logsumexp += expf(logits[i] - logit_max); - } - } - logsumexp = logf(logsumexp) + logit_max; - - for (int i = 0; i < n_logits; ++i) { - if (logits[i] > -INFINITY) { - logprobs[i] = logits[i] - logsumexp; - } else { - logprobs[i] = -INFINITY; - } - } - } - - // if sum of probability over timestamps is above any other token, sample timestamp - // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L431-L437 - { - // logsumexp over timestamps - float timestamp_logprob = -INFINITY; - { - float logsumexp = 0.0f; - const float logprob_max = *std::max_element(logprobs.begin() + vocab.token_beg, logprobs.end()); - for (int i = vocab.token_beg; i < 
n_logits; ++i) { - if (logprobs[i] > -INFINITY) { - logsumexp += expf(logprobs[i] - logprob_max); - } - } - if (logsumexp > 0.0f) { - timestamp_logprob = logf(logsumexp) + logprob_max; - } - } - - const float max_text_token_logprob = *std::max_element(logprobs.begin(), logprobs.begin() + vocab.token_beg); - - //fprintf(stderr, "timestamp_logprob=%f max_text_token_logprob=%f\n", timestamp_logprob, max_text_token_logprob); - - if (timestamp_logprob > max_text_token_logprob) { - for (int i = 0; i < vocab.token_beg; ++i) { - logits[i] = -INFINITY; - logprobs[i] = -INFINITY; - } - } - } - } - - // compute probs - { - for (int i = 0; i < n_logits; ++i) { - if (logits[i] == -INFINITY) { - probs[i] = 0.0f; - } else { - probs[i] = expf(logprobs[i]); - } - } - } - -#if 0 - // print first 100 logits - token string : logit - for (int i = 0; i < 100; i++) { - const auto token = vocab.id_to_token.at(i); - const auto prob = probs[i]; - const auto logit = logits[i]; - const auto logprob = logprobs[i]; - printf("%s : prob=%9.5f logit=%9.5f logprob=%9.5f\n", token.c_str(), prob, logit, logprob); - } - - // "And", "and", " And", " and" - printf("logits[\"and\"] = %f\n", logits[vocab.token_to_id.at("and")]); - printf("logits[\"And\"] = %f\n", logits[vocab.token_to_id.at("And")]); - printf("logits[\" and\"] = %f\n", logits[vocab.token_to_id.at(" and")]); - printf("logits[\" And\"] = %f\n", logits[vocab.token_to_id.at(" And")]); - printf("logits[\" so\"] = %f\n", logits[vocab.token_to_id.at(" so")]); - - printf("logprobs[\"and\"] = %f\n", logprobs[vocab.token_to_id.at("and")]); - printf("logprobs[\"And\"] = %f\n", logprobs[vocab.token_to_id.at("And")]); - printf("logprobs[\" and\"] = %f\n", logprobs[vocab.token_to_id.at(" and")]); - printf("logprobs[\" And\"] = %f\n", logprobs[vocab.token_to_id.at(" And")]); - printf("logprobs[\" so\"] = %f\n", logprobs[vocab.token_to_id.at(" so")]); - - printf("probs[\"and\"] = %f\n", probs[vocab.token_to_id.at("and")]); - printf("probs[\"And\"] = %f\n", probs[vocab.token_to_id.at("And")]); - printf("probs[\" and\"] = %f\n", probs[vocab.token_to_id.at(" and")]); - printf("probs[\" And\"] = %f\n", probs[vocab.token_to_id.at(" And")]); - printf("probs[\" so\"] = %f\n", probs[vocab.token_to_id.at(" so")]); -#endif -} - -static whisper_token_data whisper_sample_token( - whisper_context & ctx, - const whisper_decoder & decoder, - bool best) { - whisper_token_data result = { - 0, 0, 0.0f, 0.0f, 0.0f, 0.0f, -1, -1, 0.0f, - }; - - const auto & vocab = ctx.vocab; - - const auto & probs = decoder.probs; - const auto & logprobs = decoder.logprobs; - - const int n_logits = vocab.n_vocab; - - { - double sum_ts = 0.0; - double max_ts = 0.0; - - for (int i = vocab.token_beg; i < n_logits; i++) { - if (probs[i] == -INFINITY) { - continue; - } - - sum_ts += probs[i]; - if (max_ts < probs[i]) { - max_ts = probs[i]; - result.tid = i; - } - } - - result.pt = max_ts/(sum_ts + 1e-10); - result.ptsum = sum_ts; - } - - if (best) { - for (int i = 0; i < n_logits; ++i) { - if (result.p < probs[i]) { - result.id = i; - result.p = probs[i]; - result.plog = logprobs[i]; - } - } - } else { - std::discrete_distribution<> dist(probs.begin(), probs.end()); - - result.id = dist(ctx.rng); - result.p = probs[result.id]; - result.plog = logprobs[result.id]; - } - - if (result.id >= vocab.token_beg) { - result.tid = result.id; - result.pt = result.p; - } - - ctx.n_sample++; - - return result; -} - -static std::vector whisper_sample_token_topk( - whisper_context & ctx, - const whisper_decoder & decoder, - int k) { 
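// --- illustrative sketch (not part of the patch) --------------------------------
// whisper_process_logits() above finishes by log_softmax-ing the filtered logits
// and, per the OpenAI rule it references, forces a timestamp token whenever the
// combined probability mass of timestamp tokens beats the best text token.
// A minimal standalone version of that decision, assuming `token_beg` marks the
// first timestamp id and ignoring the -INFINITY masking done by the real code:
#include <algorithm>
#include <cmath>
#include <vector>

static bool must_sample_timestamp(const std::vector<float> & logits, int token_beg) {
    // numerically stable log_softmax
    const float max_l = *std::max_element(logits.begin(), logits.end());
    float sum = 0.0f;
    for (float l : logits) sum += expf(l - max_l);
    const float logsumexp = logf(sum) + max_l;

    std::vector<float> logprobs(logits.size());
    for (size_t i = 0; i < logits.size(); ++i) logprobs[i] = logits[i] - logsumexp;

    // logsumexp over the timestamp region vs. the single best text token
    const float ts_max = *std::max_element(logprobs.begin() + token_beg, logprobs.end());
    float ts_sum = 0.0f;
    for (size_t i = token_beg; i < logprobs.size(); ++i) ts_sum += expf(logprobs[i] - ts_max);
    const float timestamp_logprob = logf(ts_sum) + ts_max;

    const float max_text_logprob = *std::max_element(logprobs.begin(), logprobs.begin() + token_beg);
    return timestamp_logprob > max_text_logprob;
}
// ---------------------------------------------------------------------------------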
- const auto & vocab = ctx.vocab; - - const auto & probs = decoder.probs; - const auto & logits = decoder.logits; - const auto & logprobs = decoder.logprobs; - - const int n_logits = vocab.n_vocab; - - auto & logits_id = ctx.logits_id; - - logits_id.clear(); - for (int i = 0; i < n_logits; ++i) { - logits_id.push_back({ logits[i], i }); - } - - std::partial_sort( - logits_id.begin(), - logits_id.begin() + k, logits_id.end(), - [](const std::pair & a, const std::pair & b) { - return a.first > b.first; - }); - - std::vector result; - result.reserve(k); - - whisper_token tid = vocab.token_beg; - - float pt = 0.0; - float ptsum = 0.0; - - { - double sum_ts = 0.0; - double max_ts = 0.0; - - for (int i = vocab.token_beg; i < n_logits; i++) { - if (probs[i] == -INFINITY) { - continue; - } - - sum_ts += probs[i]; - if (max_ts < probs[i]) { - max_ts = probs[i]; - tid = i; - } - } - - pt = max_ts/(sum_ts + 1e-10); - ptsum = sum_ts; - } - - for (int i = 0; i < k; ++i) { - const auto id = logits_id[i].second; - - result.push_back({ id, tid, probs[id], logprobs[id], pt, ptsum, -1, -1, 0.0f, }); - - if (result[i].id >= vocab.token_beg) { - result[i].tid = result[i].id; - result[i].pt = result[i].p; - } - } - - ctx.n_sample++; - - return result; -} - -// ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L178-L192 -static void whisper_sequence_score( - const struct whisper_full_params & params, - whisper_sequence & sequence) { - if (sequence.result_len == 0) { - return; - } - - double result = 0.0f; - - for (int i = 0; i < sequence.result_len; ++i) { - result += sequence.tokens[i].plog; - } - - sequence.sum_logprobs = result; - sequence.avg_logprobs = result/sequence.result_len; - - double penalty = sequence.result_len; - - if (params.length_penalty > 0.0f) { - penalty = pow((5.0 + penalty)/6.0, params.length_penalty); - } - - sequence.score = result/penalty; - - // compute the entropy of the sequence of the last 32 tokens - { - const int n = 32; - - int cnt = 0; - double entropy = 0.0f; - - std::map token_counts; - for (int i = std::max(0, sequence.result_len - n); i < sequence.result_len; ++i) { - token_counts[sequence.tokens[i].id]++; - cnt++; - } - - for (const auto & kv : token_counts) { - const auto p = kv.second/(double)cnt; - entropy -= p*log(p); - - //WHISPER_PRINT_DEBUG("entropy: %d %f %f, count %d\n", kv.first, p, log(p), kv.second); - } - - sequence.entropy = entropy; - } -} - -int whisper_full( - struct whisper_context * ctx, - struct whisper_full_params params, - const float * samples, - int n_samples) { - // clear old results - auto & result_all = ctx->result_all; - - result_all.clear(); - - // compute log mel spectrogram - if (params.speed_up) { - if (whisper_pcm_to_mel_phase_vocoder(ctx, samples, n_samples, params.n_threads) != 0) { - fprintf(stderr, "%s: failed to compute log mel spectrogram\n", __func__); - return -1; - } - } else { - if (whisper_pcm_to_mel(ctx, samples, n_samples, params.n_threads) != 0) { - fprintf(stderr, "%s: failed to compute log mel spectrogram\n", __func__); - return -2; - } - } - - // auto-detect language if not specified - if (params.language == nullptr || strlen(params.language) == 0 || strcmp(params.language, "auto") == 0) { - std::vector probs(whisper_lang_max_id() + 1, 0.0f); - - const auto lang_id = whisper_lang_auto_detect(ctx, 0, params.n_threads, probs.data()); - if (lang_id < 0) { - fprintf(stderr, "%s: failed to auto-detect language\n", __func__); - return -3; - } - ctx->lang_id = 
lang_id; - params.language = whisper_lang_str(lang_id); - - fprintf(stderr, "%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[whisper_lang_id(params.language)]); - } - - if (params.token_timestamps) { - ctx->t_beg = 0; - ctx->t_last = 0; - ctx->tid_last = 0; - ctx->energy = get_signal_energy(samples, n_samples, 32); - } - - const int seek_start = params.offset_ms/10; - const int seek_end = seek_start + (params.duration_ms == 0 ? whisper_n_len(ctx) : params.duration_ms/10); - - // if length of spectrogram is less than 1s (100 samples), then return - // basically don't process anything that is less than 1s - // see issue #39: https://github.com/ggerganov/whisper.cpp/issues/39 - if (seek_end < seek_start + (params.speed_up ? 50 : 100)) { - return 0; - } - - // a set of temperatures to use - // [ t0, t0 + delta, t0 + 2*delta, ..., < 1.0f + 1e-6f ] - std::vector temperatures; - if (params.temperature_inc > 0.0f) { - for (float t = params.temperature; t < 1.0f + 1e-6f; t += params.temperature_inc) { - temperatures.push_back(t); - } - } else { - temperatures.push_back(params.temperature); - } - - // initialize the decoders - int n_decoders = 1; - - switch (params.strategy) { - case WHISPER_SAMPLING_GREEDY: - { - n_decoders = params.greedy.best_of; - } break; - case WHISPER_SAMPLING_BEAM_SEARCH: - { - n_decoders = std::max(params.greedy.best_of, params.beam_search.beam_size); - } break; - }; - - n_decoders = std::max(1, n_decoders); - - // TAGS: WHISPER_DECODER_INIT - for (int j = 1; j < n_decoders; j++) { - auto & decoder = ctx->decoders[j]; - - if (decoder.kv_self.ctx == nullptr) { - decoder.kv_self = ctx->decoders[0].kv_self; - if (!kv_cache_reinit(decoder.kv_self)) { - fprintf(stderr, "%s: kv_cache_reinit() failed for self-attention, decoder %d\n", __func__, j); - return -4; - } - - WHISPER_PRINT_DEBUG("%s: initialized self-attention kv cache, decoder %d\n", __func__, j); - - decoder.sequence.tokens.reserve(ctx->decoders[0].sequence.tokens.capacity()); - - decoder.probs.resize (ctx->vocab.n_vocab); - decoder.logits.resize (ctx->vocab.n_vocab); - decoder.logprobs.resize(ctx->vocab.n_vocab); - } - } - - // the accumulated text context so far - auto & prompt_past = ctx->prompt_past; - if (params.no_context) { - prompt_past.clear(); - } - - // prepend the prompt tokens to the prompt_past - if (params.prompt_tokens && params.prompt_n_tokens > 0) { - // parse tokens from the pointer - for (int i = 0; i < params.prompt_n_tokens; i++) { - prompt_past.push_back(params.prompt_tokens[i]); - } - std::rotate(prompt_past.begin(), prompt_past.end() - params.prompt_n_tokens, prompt_past.end()); - } - - // overwrite audio_ctx, max allowed is hparams.n_audio_ctx - if (params.audio_ctx > whisper_n_audio_ctx(ctx)) { - fprintf(stderr, "%s: audio_ctx is larger than the maximum allowed (%d > %d)\n", __func__, params.audio_ctx, whisper_n_audio_ctx(ctx)); - return -5; - } - ctx->exp_n_audio_ctx = params.audio_ctx; - - // these tokens determine the task that will be performed - std::vector prompt_init = { whisper_token_sot(ctx) }; - if (whisper_is_multilingual(ctx)) { - const int lang_id = whisper_lang_id(params.language); - ctx->lang_id = lang_id; - prompt_init.push_back(whisper_token_lang(ctx, lang_id)); - if (params.translate) { - prompt_init.push_back(whisper_token_translate()); - } else { - prompt_init.push_back(whisper_token_transcribe()); - } - } - - int progress_prev = 0; - int progress_step = 5; - - int seek = seek_start; - - std::vector prompt; - 
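// --- illustrative caller-side sketch (not part of the patch) ---------------------
// How the pieces just set up are driven from the public API: leaving `language`
// as "auto" takes the whisper_lang_auto_detect() path above, and an initial prompt
// is passed via prompt_tokens/prompt_n_tokens after whisper_tokenize().
// The model path and prompt text below are placeholders.
#include <cstdio>
#include <vector>
#include "whisper.h"

int transcribe_with_prompt(const float * pcmf32, int n_samples) {
    struct whisper_context * wctx = whisper_init_from_file("models/ggml-base.bin");
    if (!wctx) return 1;

    whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
    wparams.language = "auto"; // auto-detect the spoken language

    // optional: condition the decoder on a short initial prompt
    std::vector<whisper_token> prompt(64);
    const int n_prompt = whisper_tokenize(wctx, "domain-specific glossary here", prompt.data(), (int) prompt.size());
    if (n_prompt > 0) {
        wparams.prompt_tokens   = prompt.data();
        wparams.prompt_n_tokens = n_prompt;
    }

    const int ret = whisper_full(wctx, wparams, pcmf32, n_samples);
    for (int i = 0; i < whisper_full_n_segments(wctx); ++i) {
        printf("%s\n", whisper_full_get_segment_text(wctx, i));
    }

    whisper_free(wctx);
    return ret;
}
// ---------------------------------------------------------------------------------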
prompt.reserve(whisper_n_text_ctx(ctx)); - - // beam-search helpers - struct kv_buf { - std::vector k; - std::vector v; - }; - - std::vector kv_bufs; - - struct beam_candidate { - int decoder_idx; - int seek_delta; - - bool has_ts; - - whisper_sequence sequence; - }; - - std::vector beam_candidates; - - // main loop - while (true) { - const int progress_cur = (100*(seek - seek_start))/(seek_end - seek_start); - while (progress_cur >= progress_prev + progress_step) { - progress_prev += progress_step; - if (params.print_progress) { - fprintf(stderr, "%s: progress = %3d%%\n", __func__, progress_prev); - } - } - - // of only 1 second left, then stop - if (seek + 100 >= seek_end) { - break; - } - - if (params.encoder_begin_callback) { - if (params.encoder_begin_callback(ctx, params.encoder_begin_callback_user_data) == false) { - fprintf(stderr, "%s: encoder_begin_callback returned false - aborting\n", __func__); - break; - } - } - - // encode audio features starting at offset seek - if (!whisper_encode(*ctx, seek, params.n_threads)) { - fprintf(stderr, "%s: failed to encode\n", __func__); - return -6; - } - - // if there is a very short audio segment left to process, we remove any past prompt since it tends - // to confuse the decoder and often make it repeat or hallucinate stuff - if (seek > seek_start && seek + 500 >= seek_end) { - prompt_past.clear(); - } - - int best_decoder_id = 0; - - for (int it = 0; it < (int) temperatures.size(); ++it) { - const float t_cur = temperatures[it]; - - int n_decoders_cur = 1; - - switch (params.strategy) { - case whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY: - { - if (t_cur > 0.0f) { - n_decoders_cur = params.greedy.best_of; - } - } break; - case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH: - { - if (t_cur > 0.0f) { - n_decoders_cur = params.greedy.best_of; - } else { - n_decoders_cur = params.beam_search.beam_size; - } - } break; - }; - - n_decoders_cur = std::max(1, n_decoders_cur); - - WHISPER_PRINT_DEBUG("\n%s: decoding with %d decoders, temperature = %.2f\n", __func__, n_decoders_cur, t_cur); - - // TAGS: WHISPER_DECODER_INIT - for (int j = 0; j < n_decoders_cur; ++j) { - auto & decoder = ctx->decoders[j]; - - decoder.kv_self.n = 0; - - decoder.sequence.tokens.clear(); - decoder.sequence.result_len = 0; - decoder.sequence.sum_logprobs_all = 0.0; - decoder.sequence.sum_logprobs = -INFINITY; - decoder.sequence.avg_logprobs = -INFINITY; - decoder.sequence.entropy = 0.0; - decoder.sequence.score = -INFINITY; - - decoder.seek_delta = 100*WHISPER_CHUNK_SIZE; - - decoder.failed = false; - decoder.completed = false; - decoder.has_ts = false; - } - - // init prompt and kv cache for the current iteration - // run whisper_decoder() only for decoder 0 and copy the results for the other decoders - { - prompt.clear(); - - // if we have already generated some text, use it as a prompt to condition the next generation - if (!prompt_past.empty() && t_cur < 0.5f && params.n_max_text_ctx > 0) { - int n_take = std::min(std::min(params.n_max_text_ctx, whisper_n_text_ctx(ctx)/2), int(prompt_past.size())); - - prompt = { whisper_token_prev(ctx) }; - prompt.insert(prompt.begin() + 1, prompt_past.end() - n_take, prompt_past.end()); - } - - // init new transcription with sot, language (opt) and task tokens - prompt.insert(prompt.end(), prompt_init.begin(), prompt_init.end()); - - // print the prompt - WHISPER_PRINT_DEBUG("\n\n"); - for (int i = 0; i < (int) prompt.size(); i++) { - WHISPER_PRINT_DEBUG("%s: prompt[%d] = %s\n", __func__, i, 
ctx->vocab.id_to_token.at(prompt[i]).c_str()); - } - WHISPER_PRINT_DEBUG("\n\n"); - - if (!whisper_decode(*ctx, ctx->decoders[0], prompt.data(), prompt.size(), 0, params.n_threads)) { - fprintf(stderr, "%s: failed to decode\n", __func__); - return -7; - } - - { - const int64_t t_start_sample_us = ggml_time_us(); - - whisper_process_logits(*ctx, params, ctx->decoders[0], t_cur); - - ctx->decoders[0].kv_self.n += prompt.size(); - - for (int j = 1; j < n_decoders_cur; ++j) { - auto & decoder = ctx->decoders[j]; - - memcpy(decoder.kv_self.k->data, ctx->decoders[0].kv_self.k->data, ggml_nbytes(decoder.kv_self.k)); - memcpy(decoder.kv_self.v->data, ctx->decoders[0].kv_self.v->data, ggml_nbytes(decoder.kv_self.v)); - - decoder.kv_self.n += prompt.size(); - - memcpy(decoder.probs.data(), ctx->decoders[0].probs.data(), decoder.probs.size()*sizeof(decoder.probs[0])); - memcpy(decoder.logits.data(), ctx->decoders[0].logits.data(), decoder.logits.size()*sizeof(decoder.logits[0])); - memcpy(decoder.logprobs.data(), ctx->decoders[0].logprobs.data(), decoder.logprobs.size()*sizeof(decoder.logprobs[0])); - } - - ctx->t_sample_us += ggml_time_us() - t_start_sample_us; - } - } - - for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) { - const int64_t t_start_sample_us = ggml_time_us(); - - // store the KV caches of all decoders when doing beam-search - if (params.strategy == whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH) { - kv_bufs.resize(n_decoders_cur); - for (int j = 0; j < n_decoders_cur; ++j) { - auto & decoder = ctx->decoders[j]; - - if (decoder.completed || decoder.failed) { - continue; - } - - kv_bufs[j].k.resize(ggml_nbytes(decoder.kv_self.k)); - kv_bufs[j].v.resize(ggml_nbytes(decoder.kv_self.v)); - - memcpy(kv_bufs[j].k.data(), decoder.kv_self.k->data, kv_bufs[j].k.size()); - memcpy(kv_bufs[j].v.data(), decoder.kv_self.v->data, kv_bufs[j].v.size()); - } - - beam_candidates.clear(); - } - - // generate new sequence candidates for each decoder - for (int j = 0; j < n_decoders_cur; ++j) { - auto & decoder = ctx->decoders[j]; - - if (decoder.completed || decoder.failed) { - continue; - } - - switch (params.strategy) { - case whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY: - { - if (t_cur < 1e-6f) { - decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, decoder, true)); - } else { - decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, decoder, false)); - } - - decoder.sequence.sum_logprobs_all += decoder.sequence.tokens.back().plog; - } break; - case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH: - { - const auto tokens_new = whisper_sample_token_topk(*ctx, decoder, params.beam_search.beam_size); - - for (const auto & token : tokens_new) { - beam_candidates.push_back({ j, decoder.seek_delta, decoder.has_ts, decoder.sequence }); - beam_candidates.back().sequence.tokens.push_back(token); - beam_candidates.back().sequence.sum_logprobs_all += token.plog; - - //WHISPER_PRINT_DEBUG("%s: beam candidate: %s (%f, %f)\n", __func__, ctx->vocab.id_to_token.at(token.id).c_str(), token.plog, beam_candidates.back().sequence.sum_logprobs_all); - } - } break; - }; - } - - // for beam-search, choose the top candidates and update the KV caches - if (params.strategy == whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH) { - std::sort( - beam_candidates.begin(), - beam_candidates.end(), - [](const beam_candidate & a, const beam_candidate & b) { - return a.sequence.sum_logprobs_all > b.sequence.sum_logprobs_all; - }); - - int cur_c = 0; - - for (int j = 0; j < 
n_decoders_cur; ++j) { - auto & decoder = ctx->decoders[j]; - - if (decoder.completed || decoder.failed) { - continue; - } - - auto & cur = beam_candidates[cur_c++]; - - while (beam_candidates.size() > cur_c && beam_candidates[cur_c].sequence.sum_logprobs_all == cur.sequence.sum_logprobs_all && i > 0) { - ++cur_c; - } - - decoder.sequence = cur.sequence; - decoder.seek_delta = cur.seek_delta; - decoder.has_ts = cur.has_ts; - - memcpy(decoder.kv_self.k->data, kv_bufs[cur.decoder_idx].k.data(), kv_bufs[cur.decoder_idx].k.size()); - memcpy(decoder.kv_self.v->data, kv_bufs[cur.decoder_idx].v.data(), kv_bufs[cur.decoder_idx].v.size()); - - WHISPER_PRINT_DEBUG("%s: beam search: decoder %d: from decoder %d: token = %10s, plog = %8.5f, sum_logprobs = %8.5f\n", - __func__, j, cur.decoder_idx, ctx->vocab.id_to_token.at(decoder.sequence.tokens.back().id).c_str(), decoder.sequence.tokens.back().plog, decoder.sequence.sum_logprobs_all); - } - } - - // update the decoder state - // - check if the sequence is completed - // - check if the sequence is failed - // - update sliding window based on timestamp tokens - for (int j = 0; j < n_decoders_cur; ++j) { - auto & decoder = ctx->decoders[j]; - - if (decoder.completed || decoder.failed) { - continue; - } - - auto & has_ts = decoder.has_ts; - auto & failed = decoder.failed; - auto & completed = decoder.completed; - auto & seek_delta = decoder.seek_delta; - auto & result_len = decoder.sequence.result_len; - - { - const auto & token = decoder.sequence.tokens.back(); - - // timestamp token - update sliding window - if (token.id > whisper_token_beg(ctx)) { - const int seek_delta_new = 2*(token.id - whisper_token_beg(ctx)); - - // do not allow to go back in time - if (has_ts && seek_delta > seek_delta_new && result_len < i) { - failed = true; // TODO: maybe this is not a failure ? - continue; - } - - seek_delta = seek_delta_new; - result_len = i + 1; - has_ts = true; - } - -#ifdef WHISPER_DEBUG - { - const auto tt = token.pt > 0.10 ? ctx->vocab.id_to_token.at(token.tid) : "[?]"; - WHISPER_PRINT_DEBUG("%s: id = %3d, decoder = %d, token = %6d, p = %6.3f, ts = %10s, %6.3f, result_len = %4d '%s'\n", - __func__, i, j, token.id, token.p, tt.c_str(), token.pt, result_len, ctx->vocab.id_to_token.at(token.id).c_str()); - } -#endif - - // end of segment - if (token.id == whisper_token_eot(ctx) || // end of text token - (params.max_tokens > 0 && i >= params.max_tokens) || // max tokens per segment reached - (has_ts && seek + seek_delta + 100 >= seek_end) // end of audio reached - ) { - if (result_len == 0) { - if (seek + seek_delta + 100 >= seek_end) { - result_len = i + 1; - } else { - failed = true; - continue; - } - } - - if (params.single_segment) { - result_len = i + 1; - seek_delta = 100*WHISPER_CHUNK_SIZE; - } - - completed = true; - continue; - } - - // TESTS: if no tensors are loaded, it means we are running tests - if (ctx->model.n_loaded == 0) { - seek_delta = 100*WHISPER_CHUNK_SIZE; - completed = true; - continue; - } - } - - // sometimes, the decoding can get stuck in a repetition loop - // this is an attempt to mitigate such cases - we flag the decoding as failed and use a fallback strategy - if (i == n_max - 1 && (result_len == 0 || seek_delta < 100*WHISPER_CHUNK_SIZE/2)) { - failed = true; - continue; - } - } - - // check if all decoders have finished (i.e. 
completed or failed) - { - bool completed_all = true; - - for (int j = 0; j < n_decoders_cur; ++j) { - auto & decoder = ctx->decoders[j]; - - if (decoder.completed || decoder.failed) { - continue; - } - - completed_all = false; - } - - if (completed_all) { - break; - } - } - - ctx->t_sample_us += ggml_time_us() - t_start_sample_us; - - // obtain logits for the next token - for (int j = 0; j < n_decoders_cur; ++j) { - auto & decoder = ctx->decoders[j]; - - if (decoder.failed || decoder.completed) { - continue; - } - - decoder.tokens_tmp.resize(1); - decoder.tokens_tmp[0] = decoder.sequence.tokens.back().id; - - //WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, kv_self.n %d, seek_delta %d\n", __func__, j, decoder.tokens_tmp[0], decoder.kv_self.n, decoder.seek_delta); - - if (!whisper_decode(*ctx, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.kv_self.n, params.n_threads)) { - fprintf(stderr, "%s: failed to decode\n", __func__); - return -8; - } - - { - const int64_t t_start_sample_us = ggml_time_us(); - - whisper_process_logits(*ctx, params, decoder, t_cur); - - ++decoder.kv_self.n; - - ctx->t_sample_us += ggml_time_us() - t_start_sample_us; - } - } - } - - // rank the resulting sequences and select the best one - { - double best_score = -INFINITY; - - for (int j = 0; j < n_decoders_cur; ++j) { - auto & decoder = ctx->decoders[j]; - - if (decoder.failed) { - continue; - } - - decoder.sequence.tokens.resize(decoder.sequence.result_len); - whisper_sequence_score(params, decoder.sequence); - - WHISPER_PRINT_DEBUG("%s: decoder %2d: score = %8.5f, result_len = %3d, avg_logprobs = %8.5f, entropy = %8.5f\n", - __func__, j, decoder.sequence.score, decoder.sequence.result_len, decoder.sequence.avg_logprobs, decoder.sequence.entropy); - - if (decoder.sequence.result_len > 32 && decoder.sequence.entropy < params.entropy_thold) { - WHISPER_PRINT_DEBUG("%s: decoder %2d: failed due to entropy %8.5f < %8.5f\n", - __func__, j, decoder.sequence.entropy, params.entropy_thold); - - decoder.failed = true; - ctx->n_fail_h++; - - continue; - } - - if (best_score < decoder.sequence.score) { - best_score = decoder.sequence.score; - best_decoder_id = j; - } - } - - WHISPER_PRINT_DEBUG("%s: best decoder = %d\n", __func__, best_decoder_id); - } - - // was the decoding successful for the current temperature? 
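// --- illustrative sketch (not part of the patch) --------------------------------
// The acceptance rule applied in the block below, in isolation: a decode at the
// current temperature is kept only if the best decoder did not fail (e.g. on the
// entropy check above) and its average log-probability clears logprob_thold;
// otherwise the loop falls back to the next, higher temperature.
struct decode_attempt {
    bool   failed;        // set by the entropy / repetition checks
    double avg_logprobs;  // mean per-token log-probability of the best sequence
};

static bool accept_decode(const decode_attempt & a, float logprob_thold) {
    return !a.failed && a.avg_logprobs >= logprob_thold;
}
// usage: try temperatures 0.0, 0.2, ..., 1.0 and stop at the first accepted decode
// ---------------------------------------------------------------------------------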
- { - bool success = true; - - const auto & decoder = ctx->decoders[best_decoder_id]; - - if (decoder.failed || decoder.sequence.avg_logprobs < params.logprob_thold) { - success = false; - ctx->n_fail_p++; - } - - if (success) { - //for (auto & token : ctx->decoders[best_decoder_id].sequence.tokens) { - // WHISPER_PRINT_DEBUG("%s: token = %d, p = %6.3f, pt = %6.3f, ts = %s, str = %s\n", __func__, token.id, token.p, token.pt, ctx->vocab.id_to_token.at(token.tid).c_str(), ctx->vocab.id_to_token.at(token.id).c_str()); - //} - - break; - } - } - - WHISPER_PRINT_DEBUG("\n%s: failed to decode with temperature = %.2f\n", __func__, t_cur); - } - - // output results through a user-provided callback - { - const auto & best_decoder = ctx->decoders[best_decoder_id]; - - const auto seek_delta = best_decoder.seek_delta; - const auto result_len = best_decoder.sequence.result_len; - - const auto & tokens_cur = best_decoder.sequence.tokens; - - //WHISPER_PRINT_DEBUG("prompt_init.size() = %d, prompt.size() = %d, result_len = %d, seek_delta = %d\n", prompt_init.size(), prompt.size(), result_len, seek_delta); - - // update prompt_past - prompt_past.clear(); - if (prompt.front() == whisper_token_prev(ctx)) { - prompt_past.insert(prompt_past.end(), prompt.begin() + 1, prompt.end() - prompt_init.size()); - } - - for (int i = 0; i < result_len; ++i) { - prompt_past.push_back(tokens_cur[i].id); - } - - // store the text from this iteration - if (!tokens_cur.empty() && ctx->model.n_loaded > 0) { - int i0 = 0; - auto t0 = seek + 2*(tokens_cur.front().tid - whisper_token_beg(ctx)); - - std::string text; - - for (int i = 0; i < (int) tokens_cur.size(); i++) { - //printf("%s: %18s %6.3f %18s %6.3f\n", __func__, - // ctx->vocab.id_to_token[tokens_cur[i].id].c_str(), tokens_cur[i].p, - // ctx->vocab.id_to_token[tokens_cur[i].tid].c_str(), tokens_cur[i].pt); - - if (params.print_special == false && tokens_cur[i].id >= whisper_token_eot(ctx)) { - } else { - text += whisper_token_to_str(ctx, tokens_cur[i].id); - } - - if (tokens_cur[i].id > whisper_token_beg(ctx) && !params.single_segment) { - const auto t1 = seek + 2*(tokens_cur[i].tid - whisper_token_beg(ctx)); - - if (!text.empty()) { - const auto tt0 = params.speed_up ? 2*t0 : t0; - const auto tt1 = params.speed_up ? 2*t1 : t1; - - if (params.print_realtime) { - if (params.print_timestamps) { - printf("[%s --> %s] %s\n", to_timestamp(tt0).c_str(), to_timestamp(tt1).c_str(), text.c_str()); - } else { - printf("%s", text.c_str()); - fflush(stdout); - } - } - - //printf("tt0 = %d, tt1 = %d, text = %s, token = %s, token_id = %d, tid = %d\n", tt0, tt1, text.c_str(), ctx->vocab.id_to_token[tokens_cur[i].id].c_str(), tokens_cur[i].id, tokens_cur[i].tid); - - result_all.push_back({ tt0, tt1, text, {} }); - for (int j = i0; j <= i; j++) { - result_all.back().tokens.push_back(tokens_cur[j]); - } - - int n_new = 1; - - if (params.token_timestamps) { - whisper_exp_compute_token_level_timestamps( - *ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum); - - if (params.max_len > 0) { - n_new = whisper_wrap_segment(*ctx, params.max_len, params.split_on_word); - } - } - if (params.new_segment_callback) { - params.new_segment_callback(ctx, n_new, params.new_segment_callback_user_data); - } - } - text = ""; - while (i < (int) tokens_cur.size() && tokens_cur[i].id > whisper_token_beg(ctx)) { - i++; - } - i--; - t0 = t1; - i0 = i + 1; - } - } - - if (!text.empty()) { - const auto t1 = seek + seek_delta; - - const auto tt0 = params.speed_up ? 
2*t0 : t0; - const auto tt1 = params.speed_up ? 2*t1 : t1; - - if (params.print_realtime) { - if (params.print_timestamps) { - printf("[%s --> %s] %s\n", to_timestamp(tt0).c_str(), to_timestamp(tt1).c_str(), text.c_str()); - } else { - printf("%s", text.c_str()); - fflush(stdout); - } - } - - result_all.push_back({ tt0, tt1, text, {} }); - for (int j = i0; j < (int) tokens_cur.size(); j++) { - result_all.back().tokens.push_back(tokens_cur[j]); - } - - int n_new = 1; - - if (params.token_timestamps) { - whisper_exp_compute_token_level_timestamps( - *ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum); - - if (params.max_len > 0) { - n_new = whisper_wrap_segment(*ctx, params.max_len, params.split_on_word); - } - } - if (params.new_segment_callback) { - params.new_segment_callback(ctx, n_new, params.new_segment_callback_user_data); - } - } - } - - // update audio window - seek += seek_delta; - - WHISPER_PRINT_DEBUG("seek = %d, seek_delta = %d\n", seek, seek_delta); - } - } - - return 0; -} - -int whisper_full_parallel( - struct whisper_context * ctx, - struct whisper_full_params params, - const float * samples, - int n_samples, - int n_processors) { - if (n_processors == 1) { - return whisper_full(ctx, params, samples, n_samples); - } - - int ret = 0; - - // prepare separate contexts for each thread - std::vector ctxs(n_processors - 1); - - for (int i = 0; i < n_processors - 1; ++i) { - auto & ctx_p = ctxs[i]; - - ctx_p = *ctx; - - ctx_p.logits.reserve(ctx_p.vocab.n_vocab*ctx_p.model.hparams.n_text_ctx); - - ctx_p.logits_id.reserve(ctx_p.vocab.n_vocab); - - if (!kv_cache_reinit(ctx_p.kv_cross)) { - fprintf(stderr, "%s: kv_cache_reinit() failed for cross-attention, processor %d\n", __func__, i); - return false; - } - - // TAGS: WHISPER_DECODER_INIT - for (int j = 0; j < WHISPER_MAX_DECODERS; ++j) { - if (ctx_p.decoders[j].kv_self.ctx && !kv_cache_reinit(ctx_p.decoders[j].kv_self)) { - fprintf(stderr, "%s: kv_cache_reinit() failed for self-attention, decoder %d, processor %d\n", __func__, j, i); - return false; - } - - ctx_p.decoders[j].sequence.tokens.reserve(ctx_p.model.hparams.n_text_ctx); - - ctx_p.decoders[j].probs.reserve (ctx_p.vocab.n_vocab); - ctx_p.decoders[j].logits.reserve (ctx_p.vocab.n_vocab); - ctx_p.decoders[j].logprobs.reserve(ctx_p.vocab.n_vocab); - } - } - - const int offset_samples = (WHISPER_SAMPLE_RATE*params.offset_ms)/1000; - const int n_samples_per_processor = (n_samples - offset_samples)/n_processors; - - // the calling thread will process the first chunk - // while the other threads will process the remaining chunks - - std::vector workers(n_processors - 1); - for (int i = 0; i < n_processors - 1; ++i) { - const int start_samples = offset_samples + (i + 1)*n_samples_per_processor; - const int n_samples_cur = (i == n_processors - 2) ? 
n_samples - start_samples : n_samples_per_processor; - - auto params_cur = params; - - params_cur.offset_ms = 0; - params_cur.print_progress = false; - params_cur.print_realtime = false; - - params_cur.new_segment_callback = nullptr; - params_cur.new_segment_callback_user_data = nullptr; - - workers[i] = std::thread(whisper_full, &ctxs[i], std::move(params_cur), samples + start_samples, n_samples_cur); - } - - { - auto params_cur = params; - - ret = whisper_full(ctx, std::move(params_cur), samples, offset_samples + n_samples_per_processor); - } - - for (int i = 0; i < n_processors - 1; ++i) { - workers[i].join(); - } - - const int64_t offset_t = (int64_t) params.offset_ms/10.0; - - // combine results into ctx->result_all - for (int i = 0; i < n_processors - 1; ++i) { - auto & results_i = ctxs[i].result_all; - - for (auto & result : results_i) { - // correct the segment timestamp taking into account the offset - result.t0 += 100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t; - result.t1 += 100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t; - - // make sure that segments are not overlapping - if (!ctx->result_all.empty()) { - result.t0 = std::max(result.t0, ctx->result_all.back().t1); - } - - ctx->result_all.push_back(std::move(result)); - - // call the new_segment_callback for each segment - if (params.new_segment_callback) { - params.new_segment_callback(ctx, 1, params.new_segment_callback_user_data); - } - } - - ctx->t_mel_us += ctxs[i].t_mel_us; - ctx->t_sample_us += ctxs[i].t_sample_us; - ctx->t_encode_us += ctxs[i].t_encode_us; - ctx->t_decode_us += ctxs[i].t_decode_us; - - kv_cache_free(ctx->kv_cross); - - for (int j = 0; j < WHISPER_MAX_DECODERS; ++j) { - kv_cache_free(ctx->decoders[j].kv_self); - } - } - - // average the timings - ctx->t_mel_us /= n_processors; - ctx->t_sample_us /= n_processors; - ctx->t_encode_us /= n_processors; - ctx->t_decode_us /= n_processors; - - // print information about the audio boundaries - fprintf(stderr, "\n"); - fprintf(stderr, "%s: the audio has been split into %d chunks at the following times:\n", __func__, n_processors); - for (int i = 0; i < n_processors - 1; ++i) { - fprintf(stderr, "%s: split %d - %s\n", __func__, (i + 1), to_timestamp(100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t).c_str()); - } - fprintf(stderr, "%s: the transcription quality may be degraded near these boundaries\n", __func__); - - return ret; -} - -int whisper_full_n_segments(struct whisper_context * ctx) { - return ctx->result_all.size(); -} - -int whisper_full_lang_id(struct whisper_context * ctx) { - return ctx->lang_id; -} - -int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment) { - return ctx->result_all[i_segment].t0; -} - -int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment) { - return ctx->result_all[i_segment].t1; -} - -const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment) { - return ctx->result_all[i_segment].text.c_str(); -} - -int whisper_full_n_tokens(struct whisper_context * ctx, int i_segment) { - return ctx->result_all[i_segment].tokens.size(); -} - -const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token) { - return ctx->vocab.id_to_token[ctx->result_all[i_segment].tokens[i_token].id].c_str(); -} - -whisper_token whisper_full_get_token_id(struct whisper_context * ctx, int i_segment, int i_token) { - return ctx->result_all[i_segment].tokens[i_token].id; -} - -struct 
whisper_token_data whisper_full_get_token_data(struct whisper_context * ctx, int i_segment, int i_token) { - return ctx->result_all[i_segment].tokens[i_token]; -} - -float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token) { - return ctx->result_all[i_segment].tokens[i_token].p; -} - -// ================================================================================================= - -// -// Temporary interface needed for exposing ggml interface -// Will be removed in the future when ggml becomes a separate library -// - -WHISPER_API int whisper_bench_memcpy(int n_threads) { - ggml_time_init(); - - size_t n = 50; - size_t arr = n_threads > 0 ? 1024 : n_threads; // trick to avoid compiler optimizations - - // 1 GB array - const size_t size = arr*1024llu*1024llu; - - char * src = (char *) malloc(size); - char * dst = (char *) malloc(size); - - for (size_t i = 0; i < size; i++) src[i] = i; - - memcpy(dst, src, size); // heat-up - - double tsum = 0.0; - - for (size_t i = 0; i < n; i++) { - const int64_t t0 = ggml_time_us(); - - memcpy(dst, src, size); - - const int64_t t1 = ggml_time_us(); - - tsum += (t1 - t0)*1e-6; - - src[0] = rand(); - } - - fprintf(stderr, "memcpy: %.2f GB/s\n", (double) (n*size)/(tsum*1024llu*1024llu*1024llu)); - - // needed to prevent the compile from optimizing the memcpy away - { - double sum = 0.0; - - for (size_t i = 0; i < size; i++) sum += dst[i]; - - fprintf(stderr, "sum: %s %f\n", sum == -536870910.00 ? "ok" : "error", sum); - } - - free(src); - free(dst); - - return 0; -} - -WHISPER_API int whisper_bench_ggml_mul_mat(int n_threads) { - ggml_time_init(); - - const int n_max = 128; - - const std::vector sizes = { - 64, 128, 256, 512, 1024, 2048, 4096, - }; - - const size_t N_max = sizes.back(); - - // a: N*N*sizeof(float) - // b: N*N*sizeof(float) - // c: N*N*sizeof(float) - // when F16 is used, there is an extra work buffer of size N*N*sizeof(float) - std::vector buf(4llu*N_max*N_max*sizeof(float) + 4*256); - - for (size_t i = 0; i < buf.size(); i++) buf[i] = i; - - for (int j = 0; j < (int) sizes.size(); j++) { - int n_fp16 = 0; - int n_fp32 = 0; - - // GFLOPS/s - double s_fp16 = 0.0; - double s_fp32 = 0.0; - - const size_t N = sizes[j]; - - for (int k = 0; k < 2; ++k) { - const ggml_type wtype = k == 0 ? GGML_TYPE_F16 : GGML_TYPE_F32; - - double & s = k == 0 ? s_fp16 : s_fp32; - int & n = k == 0 ? 
n_fp16 : n_fp32; - - struct ggml_init_params gparams = { - /*.mem_size =*/ buf.size(), - /*.mem_buffer =*/ buf.data(), - }; - - struct ggml_context * ctx0 = ggml_init(gparams); - - struct ggml_tensor * a = ggml_new_tensor_2d(ctx0, wtype, N, N); - struct ggml_tensor * b = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, N, N); - - struct ggml_tensor * c = ggml_mul_mat(ctx0, a, b); - - struct ggml_cgraph gf = ggml_build_forward(c); - - gf.n_threads = n_threads; - - double tsum = 0.0; - - // heat-up - ggml_graph_compute(ctx0, &gf); - - for (int i = 0; i < n_max; ++i) { - const int64_t t0 = ggml_time_us(); - - ggml_graph_compute(ctx0, &gf); - - const int64_t t1 = ggml_time_us(); - - tsum += (t1 - t0)*1e-6; - n++; - - if (tsum > 1.0 && n >= 3) { - break; - } - } - - ggml_free(ctx0); - - s = ((2.0*N*N*N*n)/tsum)*1e-9; - } - - fprintf(stderr, "ggml_mul_mat: %5zu x %5zu: F16 %8.1f GFLOPS (%3d runs) / F32 %8.1f GFLOPS (%3d runs)\n", - N, N, s_fp16, n_fp16, s_fp32, n_fp32); - } - - return 0; -} - -// ================================================================================================= - -// ================================================================================================= - -// -// Experimental stuff below -// -// Not sure if these should be part of the library at all, because the quality of the results is not -// guaranteed. Might get removed at some point unless a robust algorithm implementation is found -// - -// ================================================================================================= - -// -// token-level timestamps -// - -static int timestamp_to_sample(int64_t t, int n_samples) { - return std::max(0, std::min((int) n_samples - 1, (int) ((t*WHISPER_SAMPLE_RATE)/100))); -} - -static int64_t sample_to_timestamp(int i_sample) { - return (100ll*i_sample)/WHISPER_SAMPLE_RATE; -} - -// a cost-function / heuristic that is high for text that takes longer to pronounce -// obviously, can be improved -static float voice_length(const std::string & text) { - float res = 0.0f; - - for (char c : text) { - if (c == ' ') { - res += 0.01f; - } else if (c == ',') { - res += 2.00f; - } else if (c == '.') { - res += 3.00f; - } else if (c == '!') { - res += 3.00f; - } else if (c == '?') { - res += 3.00f; - } else if (c >= '0' && c <= '9') { - res += 3.00f; - } else { - res += 1.00f; - } - } - - return res; -} - -// average the fabs of the signal -static std::vector get_signal_energy(const float * signal, int n_samples, int n_samples_per_half_window) { - const int hw = n_samples_per_half_window; - - std::vector result(n_samples); - - for (int i = 0; i < n_samples; i++) { - float sum = 0; - for (int j = -hw; j <= hw; j++) { - if (i + j >= 0 && i + j < n_samples) { - sum += fabs(signal[i + j]); - } - } - result[i] = sum/(2*hw + 1); - } - - return result; -} - -static void whisper_exp_compute_token_level_timestamps( - struct whisper_context & ctx, - int i_segment, - float thold_pt, - float thold_ptsum) { - auto & segment = ctx.result_all[i_segment]; - auto & tokens = segment.tokens; - - const int n_samples = ctx.energy.size(); - - if (n_samples == 0) { - fprintf(stderr, "%s: no signal data available\n", __func__); - return; - } - - const int64_t t0 = segment.t0; - const int64_t t1 = segment.t1; - - const int n = tokens.size(); - - if (n == 0) { - return; - } - - if (n == 1) { - tokens[0].t0 = t0; - tokens[0].t1 = t1; - - return; - } - - auto & t_beg = ctx.t_beg; - auto & t_last = ctx.t_last; - auto & tid_last = ctx.tid_last; - - for (int j = 0; j < n; ++j) { - auto & token = 
tokens[j]; - - if (j == 0) { - if (token.id == whisper_token_beg(&ctx)) { - tokens[j ].t0 = t0; - tokens[j ].t1 = t0; - tokens[j + 1].t0 = t0; - - t_beg = t0; - t_last = t0; - tid_last = whisper_token_beg(&ctx); - } else { - tokens[j ].t0 = t_last; - } - } - - const int64_t tt = t_beg + 2*(token.tid - whisper_token_beg(&ctx)); - - tokens[j].id = token.id; - tokens[j].tid = token.tid; - tokens[j].p = token.p; - tokens[j].pt = token.pt; - tokens[j].ptsum = token.ptsum; - - tokens[j].vlen = voice_length(whisper_token_to_str(&ctx, token.id)); - - if (token.pt > thold_pt && token.ptsum > thold_ptsum && token.tid > tid_last && tt <= t1) { - if (j > 0) { - tokens[j - 1].t1 = tt; - } - tokens[j].t0 = tt; - tid_last = token.tid; - } - } - - tokens[n - 2].t1 = t1; - tokens[n - 1].t0 = t1; - tokens[n - 1].t1 = t1; - - t_last = t1; - - // find intervals of tokens with unknown timestamps - // fill the timestamps by proportionally splitting the interval based on the token voice lengths - { - int p0 = 0; - int p1 = 0; - - while (true) { - while (p1 < n && tokens[p1].t1 < 0) { - p1++; - } - - if (p1 >= n) { - p1--; - } - - //printf("p0=%d p1=%d t0=%lld t1=%lld\n", p0, p1, tokens[p0].t0, tokens[p1].t1); - - if (p1 > p0) { - double psum = 0.0; - for (int j = p0; j <= p1; j++) { - psum += tokens[j].vlen; - } - - //printf("analyzing %d - %d, psum = %f\n", p0, p1, psum); - - const double dt = tokens[p1].t1 - tokens[p0].t0; - - // split the time proportionally to the voice length - for (int j = p0 + 1; j <= p1; j++) { - const double ct = tokens[j - 1].t0 + dt*tokens[j - 1].vlen/psum; - - tokens[j - 1].t1 = ct; - tokens[j ].t0 = ct; - } - } - - p1++; - p0 = p1; - if (p1 >= n) { - break; - } - } - } - - // fix up (just in case) - for (int j = 0; j < n - 1; j++) { - if (tokens[j].t1 < 0) { - tokens[j + 1].t0 = tokens[j].t1; - } - - if (j > 0) { - if (tokens[j - 1].t1 > tokens[j].t0) { - tokens[j].t0 = tokens[j - 1].t1; - tokens[j].t1 = std::max(tokens[j].t0, tokens[j].t1); - } - } - } - - // VAD - // expand or contract tokens based on voice activity - { - const int hw = WHISPER_SAMPLE_RATE/8; - - for (int j = 0; j < n; j++) { - if (tokens[j].id >= whisper_token_eot(&ctx)) { - continue; - } - - int s0 = timestamp_to_sample(tokens[j].t0, n_samples); - int s1 = timestamp_to_sample(tokens[j].t1, n_samples); - - const int ss0 = std::max(s0 - hw, 0); - const int ss1 = std::min(s1 + hw, n_samples); - - const int ns = ss1 - ss0; - - float sum = 0.0f; - - for (int k = ss0; k < ss1; k++) { - sum += ctx.energy[k]; - } - - const float thold = 0.5*sum/ns; - - { - int k = s0; - if (ctx.energy[k] > thold && j > 0) { - while (k > 0 && ctx.energy[k] > thold) { - k--; - } - tokens[j].t0 = sample_to_timestamp(k); - if (tokens[j].t0 < tokens[j - 1].t1) { - tokens[j].t0 = tokens[j - 1].t1; - } else { - s0 = k; - } - } else { - while (ctx.energy[k] < thold && k < s1) { - k++; - } - s0 = k; - tokens[j].t0 = sample_to_timestamp(k); - } - } - - { - int k = s1; - if (ctx.energy[k] > thold) { - while (k < n_samples - 1 && ctx.energy[k] > thold) { - k++; - } - tokens[j].t1 = sample_to_timestamp(k); - if (j < ns - 1 && tokens[j].t1 > tokens[j + 1].t0) { - tokens[j].t1 = tokens[j + 1].t0; - } else { - s1 = k; - } - } else { - while (ctx.energy[k] < thold && k > s0) { - k--; - } - s1 = k; - tokens[j].t1 = sample_to_timestamp(k); - } - } - } - } - - // fixed token expand (optional) - //{ - // const int t_expand = 0; - - // for (int j = 0; j < n; j++) { - // if (j > 0) { - // tokens[j].t0 = std::max(0, (int) (tokens[j].t0 - t_expand)); - // } - 
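// --- illustrative caller-side sketch (not part of the patch) ---------------------
// The experimental timestamp machinery above is driven from the public API by
// enabling token_timestamps. Field and function names are as declared in
// whisper.h; token t0/t1 use the same time units as the segment timestamps.
#include <cstdio>
#include "whisper.h"

void dump_token_timestamps(struct whisper_context * wctx, const float * pcmf32, int n_samples) {
    whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
    wparams.token_timestamps = true;   // compute per-token t0/t1
    wparams.thold_pt         = 0.01f;  // timestamp probability threshold
    wparams.max_len          = 60;     // wrap long segments (whisper_wrap_segment)

    if (whisper_full(wctx, wparams, pcmf32, n_samples) != 0) {
        return;
    }

    for (int i = 0; i < whisper_full_n_segments(wctx); ++i) {
        for (int j = 0; j < whisper_full_n_tokens(wctx, i); ++j) {
            const whisper_token_data td = whisper_full_get_token_data(wctx, i, j);
            printf("%6lld -> %6lld : %s\n", (long long) td.t0, (long long) td.t1,
                   whisper_full_get_token_text(wctx, i, j));
        }
    }
}
// ---------------------------------------------------------------------------------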
// if (j < n - 1) { - // tokens[j].t1 = tokens[j].t1 + t_expand; - // } - // } - //} - - // debug info - //for (int j = 0; j < n; ++j) { - // const auto & token = tokens[j]; - // const auto tt = token.pt > thold_pt && token.ptsum > 0.01 ? whisper_token_to_str(&ctx, token.tid) : "[?]"; - // printf("%s: %10s %6.3f %6.3f %6.3f %6.3f %5d %5d '%s'\n", __func__, - // tt, token.p, token.pt, token.ptsum, token.vlen, (int) token.t0, (int) token.t1, whisper_token_to_str(&ctx, token.id)); - - // if (tokens[j].id >= whisper_token_eot(&ctx)) { - // continue; - // } - //} -} diff --git a/bindings/ruby/ext/whisper.h b/bindings/ruby/ext/whisper.h deleted file mode 100644 index 7eece797c16..00000000000 --- a/bindings/ruby/ext/whisper.h +++ /dev/null @@ -1,379 +0,0 @@ -#ifndef WHISPER_H -#define WHISPER_H - -#include -#include -#include - -#ifdef WHISPER_SHARED -# ifdef _WIN32 -# ifdef WHISPER_BUILD -# define WHISPER_API __declspec(dllexport) -# else -# define WHISPER_API __declspec(dllimport) -# endif -# else -# define WHISPER_API __attribute__ ((visibility ("default"))) -# endif -#else -# define WHISPER_API -#endif - -#define WHISPER_SAMPLE_RATE 16000 -#define WHISPER_N_FFT 400 -#define WHISPER_N_MEL 80 -#define WHISPER_HOP_LENGTH 160 -#define WHISPER_CHUNK_SIZE 30 - -#ifdef __cplusplus -extern "C" { -#endif - - // - // C interface - // - // The following interface is thread-safe as long as the sample whisper_context is not used by multiple threads - // concurrently. - // - // Basic usage: - // - // #include "whisper.h" - // - // ... - // - // struct whisper_context * ctx = whisper_init_from_file("/path/to/ggml-base.en.bin"); - // - // if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) { - // fprintf(stderr, "failed to process audio\n"); - // return 7; - // } - // - // const int n_segments = whisper_full_n_segments(ctx); - // for (int i = 0; i < n_segments; ++i) { - // const char * text = whisper_full_get_segment_text(ctx, i); - // printf("%s", text); - // } - // - // whisper_free(ctx); - // - // ... - // - // This is a demonstration of the most straightforward usage of the library. - // "pcmf32" contains the RAW audio data in 32-bit floating point format. - // - // The interface also allows for more fine-grained control over the computation, but it requires a deeper - // understanding of how the model works. - // - - struct whisper_context; - - typedef int whisper_token; - - typedef struct whisper_token_data { - whisper_token id; // token id - whisper_token tid; // forced timestamp token id - - float p; // probability of the token - float plog; // log probability of the token - float pt; // probability of the timestamp token - float ptsum; // sum of probabilities of all timestamp tokens - - // token-level timestamp data - // do not use if you haven't computed token-level timestamps - int64_t t0; // start time of the token - int64_t t1; // end time of the token - - float vlen; // voice length of the token - } whisper_token_data; - - typedef struct whisper_model_loader { - void * context; - - size_t (*read)(void * ctx, void * output, size_t read_size); - bool (*eof)(void * ctx); - void (*close)(void * ctx); - } whisper_model_loader; - - // Various functions for loading a ggml whisper model. - // Allocate (almost) all memory needed for the model. 
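// --- illustrative sketch (not part of the patch) --------------------------------
// Besides whisper_init_from_file() and whisper_init_from_buffer(), the loader
// struct above lets the caller stream model data from any source. A minimal
// FILE*-backed loader, with error handling trimmed; the path is a placeholder.
#include <cstdio>
#include "whisper.h"

static size_t file_read (void * ctx, void * output, size_t read_size) {
    return fread(output, 1, read_size, (FILE *) ctx);
}
static bool   file_eof  (void * ctx) { return feof((FILE *) ctx) != 0; }
static void   file_close(void * ctx) { fclose((FILE *) ctx); }

struct whisper_context * init_from_stream(const char * path) {
    FILE * f = fopen(path, "rb");
    if (f == NULL) {
        return nullptr;
    }

    whisper_model_loader loader = { /*.context =*/ f, file_read, file_eof, file_close };
    return whisper_init(&loader);  // whisper_init() is expected to call close() when done
}
// ---------------------------------------------------------------------------------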
- // Return NULL on failure - WHISPER_API struct whisper_context * whisper_init_from_file(const char * path_model); - WHISPER_API struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_size); - WHISPER_API struct whisper_context * whisper_init(struct whisper_model_loader * loader); - - // Frees all memory allocated by the model. - WHISPER_API void whisper_free(struct whisper_context * ctx); - - // Convert RAW PCM audio to log mel spectrogram. - // The resulting spectrogram is stored inside the provided whisper context. - // Returns 0 on success - WHISPER_API int whisper_pcm_to_mel( - struct whisper_context * ctx, - const float * samples, - int n_samples, - int n_threads); - - // Convert RAW PCM audio to log mel spectrogram but applies a Phase Vocoder to speed up the audio x2. - // The resulting spectrogram is stored inside the provided whisper context. - // Returns 0 on success - WHISPER_API int whisper_pcm_to_mel_phase_vocoder( - struct whisper_context* ctx, - const float* samples, - int n_samples, - int n_threads); - - - // This can be used to set a custom log mel spectrogram inside the provided whisper context. - // Use this instead of whisper_pcm_to_mel() if you want to provide your own log mel spectrogram. - // n_mel must be 80 - // Returns 0 on success - WHISPER_API int whisper_set_mel( - struct whisper_context * ctx, - const float * data, - int n_len, - int n_mel); - - // Run the Whisper encoder on the log mel spectrogram stored inside the provided whisper context. - // Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first. - // offset can be used to specify the offset of the first frame in the spectrogram. - // Returns 0 on success - WHISPER_API int whisper_encode( - struct whisper_context * ctx, - int offset, - int n_threads); - - // Run the Whisper decoder to obtain the logits and probabilities for the next token. - // Make sure to call whisper_encode() first. - // tokens + n_tokens is the provided context for the decoder. - // n_past is the number of tokens to use from previous decoder calls. - // Returns 0 on success - // TODO: add support for multiple decoders - WHISPER_API int whisper_decode( - struct whisper_context * ctx, - const whisper_token * tokens, - int n_tokens, - int n_past, - int n_threads); - - // Convert the provided text into tokens. - // The tokens pointer must be large enough to hold the resulting tokens. - // Returns the number of tokens on success, no more than n_max_tokens - // Returns -1 on failure - // TODO: not sure if correct - WHISPER_API int whisper_tokenize( - struct whisper_context * ctx, - const char * text, - whisper_token * tokens, - int n_max_tokens); - - // Largest language id (i.e. number of available languages - 1) - WHISPER_API int whisper_lang_max_id(); - - // Return the id of the specified language, returns -1 if not found - // Examples: - // "de" -> 2 - // "german" -> 2 - WHISPER_API int whisper_lang_id(const char * lang); - - // Return the short string of the specified language id (e.g. 
2 -> "de"), returns nullptr if not found - WHISPER_API const char * whisper_lang_str(int id); - - // Use mel data at offset_ms to try and auto-detect the spoken language - // Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first - // Returns the top language id or negative on failure - // If not null, fills the lang_probs array with the probabilities of all languages - // The array must be whispe_lang_max_id() + 1 in size - // ref: https://github.com/openai/whisper/blob/main/whisper/decoding.py#L18-L69 - WHISPER_API int whisper_lang_auto_detect( - struct whisper_context * ctx, - int offset_ms, - int n_threads, - float * lang_probs); - - WHISPER_API int whisper_n_len (struct whisper_context * ctx); // mel length - WHISPER_API int whisper_n_vocab (struct whisper_context * ctx); - WHISPER_API int whisper_n_text_ctx (struct whisper_context * ctx); - WHISPER_API int whisper_n_audio_ctx (struct whisper_context * ctx); - WHISPER_API int whisper_is_multilingual(struct whisper_context * ctx); - - // Token logits obtained from the last call to whisper_decode() - // The logits for the last token are stored in the last row - // Rows: n_tokens - // Cols: n_vocab - WHISPER_API float * whisper_get_logits(struct whisper_context * ctx); - - // Token Id -> String. Uses the vocabulary in the provided context - WHISPER_API const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token); - - // Special tokens - WHISPER_API whisper_token whisper_token_eot (struct whisper_context * ctx); - WHISPER_API whisper_token whisper_token_sot (struct whisper_context * ctx); - WHISPER_API whisper_token whisper_token_prev(struct whisper_context * ctx); - WHISPER_API whisper_token whisper_token_solm(struct whisper_context * ctx); - WHISPER_API whisper_token whisper_token_not (struct whisper_context * ctx); - WHISPER_API whisper_token whisper_token_beg (struct whisper_context * ctx); - WHISPER_API whisper_token whisper_token_lang(struct whisper_context * ctx, int lang_id); - - // Task tokens - WHISPER_API whisper_token whisper_token_translate (void); - WHISPER_API whisper_token whisper_token_transcribe(void); - - // Performance information - WHISPER_API void whisper_print_timings(struct whisper_context * ctx); - WHISPER_API void whisper_reset_timings(struct whisper_context * ctx); - - // Print system information - WHISPER_API const char * whisper_print_system_info(void); - - //////////////////////////////////////////////////////////////////////////// - - // Available sampling strategies - enum whisper_sampling_strategy { - WHISPER_SAMPLING_GREEDY, // similar to OpenAI's GreefyDecoder - WHISPER_SAMPLING_BEAM_SEARCH, // similar to OpenAI's BeamSearchDecoder - }; - - // Text segment callback - // Called on every newly generated text segment - // Use the whisper_full_...() functions to obtain the text segments - typedef void (*whisper_new_segment_callback)(struct whisper_context * ctx, int n_new, void * user_data); - - // Encoder begin callback - // If not NULL, called before the encoder starts - // If it returns false, the computation is aborted - typedef bool (*whisper_encoder_begin_callback)(struct whisper_context * ctx, void * user_data); - - // Parameters for the whisper_full() function - // If you chnage the order or add new parameters, make sure to update the default values in whisper.cpp: - // whisper_full_default_params() - struct whisper_full_params { - enum whisper_sampling_strategy strategy; - - int n_threads; - int n_max_text_ctx; // max tokens to use from past text as 
prompt for the decoder - int offset_ms; // start offset in ms - int duration_ms; // audio duration to process in ms - - bool translate; - bool no_context; // do not use past transcription (if any) as initial prompt for the decoder - bool single_segment; // force single segment output (useful for streaming) - bool print_special; // print special tokens (e.g. , , , etc.) - bool print_progress; // print progress information - bool print_realtime; // print results from within whisper.cpp (avoid it, use callback instead) - bool print_timestamps; // print timestamps for each text segment when printing realtime - - // [EXPERIMENTAL] token-level timestamps - bool token_timestamps; // enable token-level timestamps - float thold_pt; // timestamp token probability threshold (~0.01) - float thold_ptsum; // timestamp token sum probability threshold (~0.01) - int max_len; // max segment length in characters - bool split_on_word; // split on word rather than on token (when used with max_len) - int max_tokens; // max tokens per segment (0 = no limit) - - // [EXPERIMENTAL] speed-up techniques - // note: these can significantly reduce the quality of the output - bool speed_up; // speed-up the audio by 2x using Phase Vocoder - int audio_ctx; // overwrite the audio context size (0 = use default) - - // tokens to provide to the whisper decoder as initial prompt - // these are prepended to any existing text context from a previous call - const whisper_token * prompt_tokens; - int prompt_n_tokens; - - // for auto-detection, set to nullptr, "" or "auto" - const char * language; - - // common decoding parameters: - bool suppress_blank; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L89 - bool suppress_non_speech_tokens; // ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253 - - float temperature; // initial decoding temperature, ref: https://ai.stackexchange.com/a/32478 - float max_initial_ts; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L97 - float length_penalty; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L267 - - // fallback parameters - // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L274-L278 - float temperature_inc; - float entropy_thold; // similar to OpenAI's "compression_ratio_threshold" - float logprob_thold; - float no_speech_thold; // TODO: not implemented - - struct { - int best_of; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L264 - } greedy; - - struct { - int beam_size; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L265 - - float patience; // TODO: not implemented, ref: https://arxiv.org/pdf/2204.05424.pdf - } beam_search; - - // called for every newly generated text segment - whisper_new_segment_callback new_segment_callback; - void * new_segment_callback_user_data; - - // called each time before the encoder starts - whisper_encoder_begin_callback encoder_begin_callback; - void * encoder_begin_callback_user_data; - }; - - WHISPER_API struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy); - - // Run the entire model: PCM -> log mel spectrogram 
-> encoder -> decoder -> text - // Uses the specified decoding strategy to obtain the text. - WHISPER_API int whisper_full( - struct whisper_context * ctx, - struct whisper_full_params params, - const float * samples, - int n_samples); - - // Split the input audio in chunks and process each chunk separately using whisper_full() - // It seems this approach can offer some speedup in some cases. - // However, the transcription accuracy can be worse at the beginning and end of each chunk. - WHISPER_API int whisper_full_parallel( - struct whisper_context * ctx, - struct whisper_full_params params, - const float * samples, - int n_samples, - int n_processors); - - // Number of generated text segments. - // A segment can be a few words, a sentence, or even a paragraph. - WHISPER_API int whisper_full_n_segments(struct whisper_context * ctx); - - // Language id associated with the current context - WHISPER_API int whisper_full_lang_id(struct whisper_context * ctx); - - // Get the start and end time of the specified segment. - WHISPER_API int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment); - WHISPER_API int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment); - - // Get the text of the specified segment. - WHISPER_API const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment); - - // Get number of tokens in the specified segment. - WHISPER_API int whisper_full_n_tokens(struct whisper_context * ctx, int i_segment); - - // Get the token text of the specified token in the specified segment. - WHISPER_API const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token); - WHISPER_API whisper_token whisper_full_get_token_id (struct whisper_context * ctx, int i_segment, int i_token); - - // Get token data for the specified token in the specified segment. - // This contains probabilities, timestamps, etc. - WHISPER_API whisper_token_data whisper_full_get_token_data(struct whisper_context * ctx, int i_segment, int i_token); - - // Get the probability of the specified token in the specified segment. 
- WHISPER_API float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token); - - //////////////////////////////////////////////////////////////////////////// - - // Temporary helpers needed for exposing ggml interface - - WHISPER_API int whisper_bench_memcpy(int n_threads); - WHISPER_API int whisper_bench_ggml_mul_mat(int n_threads); - -#ifdef __cplusplus -} -#endif - -#endif From ec8e2e14e1c63b06c67eb052a55086219a5d6575 Mon Sep 17 00:00:00 2001 From: Todd Fisher Date: Tue, 14 Feb 2023 08:01:44 -0500 Subject: [PATCH 3/8] ignore these files here --- bindings/ruby/ext/.gitignore | 6 ++++++ 1 file changed, 6 insertions(+) create mode 100644 bindings/ruby/ext/.gitignore diff --git a/bindings/ruby/ext/.gitignore b/bindings/ruby/ext/.gitignore new file mode 100644 index 00000000000..1031bdd6634 --- /dev/null +++ b/bindings/ruby/ext/.gitignore @@ -0,0 +1,6 @@ +Makefile +ggml.c +ggml.h +whisper.bundle +whisper.cpp +whisper.h From 126458b6cb565cf35d1e6e3910232f70df633579 Mon Sep 17 00:00:00 2001 From: Todd Fisher Date: Tue, 14 Feb 2023 09:43:33 -0500 Subject: [PATCH 4/8] add definitions for boolean params --- bindings/ruby/ext/ruby_whisper.c | 146 +++++++++++++++++++++++++++- bindings/ruby/tests/test_whisper.rb | 89 ++++++++++++++++- 2 files changed, 233 insertions(+), 2 deletions(-) diff --git a/bindings/ruby/ext/ruby_whisper.c b/bindings/ruby/ext/ruby_whisper.c index 4abfab524b4..33324a0c08e 100644 --- a/bindings/ruby/ext/ruby_whisper.c +++ b/bindings/ruby/ext/ruby_whisper.c @@ -1,6 +1,25 @@ #include #include "ruby_whisper.h" +#define BOOL_PARAMS_SETTER(self, prop, value) \ + ruby_whisper_params *rwp; \ + Data_Get_Struct(self, ruby_whisper_params, rwp); \ + if (value == Qfalse || value == Qnil) { \ + rwp->params->prop = false; \ + } else { \ + rwp->params->prop = true; \ + } \ + return value; \ + +#define BOOL_PARAMS_GETTER(self, prop) \ + ruby_whisper_params *rwp; \ + Data_Get_Struct(self, ruby_whisper_params, rwp); \ + if (rwp->params->prop) { \ + return Qtrue; \ + } else { \ + return Qfalse; \ + } + VALUE mWhisper; VALUE cContext; VALUE cParams; @@ -71,10 +90,108 @@ static VALUE ruby_whisper_params_set_auto_detection(VALUE self, VALUE value) { ruby_whisper_params *rwp; Data_Get_Struct(self, ruby_whisper_params, rwp); if (value == Qfalse || value == Qnil) { + rwp->params->language = NULL; + } else { rwp->params->language = "auto"; + } + return value; +} +static VALUE ruby_whisper_params_get_auto_detection(VALUE self) { + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper_params, rwp); + if (rwp->params->language) { + return Qtrue; } else { - rwp->params->language = NULL; + return Qfalse; } +} +static VALUE ruby_whisper_params_set_translate(VALUE self, VALUE value) { + BOOL_PARAMS_SETTER(self, translate, value) +} +static VALUE ruby_whisper_params_get_translate(VALUE self) { + BOOL_PARAMS_GETTER(self, translate) +} +static VALUE ruby_whisper_params_set_no_context(VALUE self, VALUE value) { + BOOL_PARAMS_SETTER(self, no_context, value) +} +static VALUE ruby_whisper_params_get_no_context(VALUE self) { + BOOL_PARAMS_GETTER(self, no_context) +} +static VALUE ruby_whisper_params_set_single_segment(VALUE self, VALUE value) { + BOOL_PARAMS_SETTER(self, single_segment, value) +} +static VALUE ruby_whisper_params_get_single_segment(VALUE self) { + BOOL_PARAMS_GETTER(self, single_segment) +} +static VALUE ruby_whisper_params_set_print_special(VALUE self, VALUE value) { + BOOL_PARAMS_SETTER(self, print_special, value) +} +static VALUE 
ruby_whisper_params_get_print_special(VALUE self) { + BOOL_PARAMS_GETTER(self, print_special) +} +static VALUE ruby_whisper_params_set_print_progress(VALUE self, VALUE value) { + BOOL_PARAMS_SETTER(self, print_progress, value) +} +static VALUE ruby_whisper_params_get_print_progress(VALUE self) { + BOOL_PARAMS_GETTER(self, print_progress) +} +static VALUE ruby_whisper_params_set_print_realtime(VALUE self, VALUE value) { + BOOL_PARAMS_SETTER(self, print_realtime, value) +} +static VALUE ruby_whisper_params_get_print_realtime(VALUE self) { + BOOL_PARAMS_GETTER(self, print_realtime) +} +static VALUE ruby_whisper_params_set_print_timestamps(VALUE self, VALUE value) { + BOOL_PARAMS_SETTER(self, print_timestamps, value) +} +static VALUE ruby_whisper_params_get_print_timestamps(VALUE self) { + BOOL_PARAMS_GETTER(self, print_timestamps) +} +static VALUE ruby_whisper_params_set_suppress_blank(VALUE self, VALUE value) { + BOOL_PARAMS_SETTER(self, suppress_blank, value) +} +static VALUE ruby_whisper_params_get_suppress_blank(VALUE self) { + BOOL_PARAMS_GETTER(self, suppress_blank) +} +static VALUE ruby_whisper_params_set_suppress_non_speech_tokens(VALUE self, VALUE value) { + BOOL_PARAMS_SETTER(self, suppress_non_speech_tokens, value) +} +static VALUE ruby_whisper_params_get_suppress_non_speech_tokens(VALUE self) { + BOOL_PARAMS_GETTER(self, suppress_non_speech_tokens) +} + +static VALUE ruby_whisper_params_get_offset(VALUE self) { + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper_params, rwp); + return INT2NUM(rwp->params->offset_ms); +} +static VALUE ruby_whisper_params_set_offset(VALUE self, VALUE value) { + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper_params, rwp); + rwp->params->offset_ms = NUM2INT(value); + return value; +} +static VALUE ruby_whisper_params_get_duration(VALUE self) { + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper_params, rwp); + return INT2NUM(rwp->params->duration_ms); +} +static VALUE ruby_whisper_params_set_duration(VALUE self, VALUE value) { + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper_params, rwp); + rwp->params->duration_ms = NUM2INT(value); + return value; +} + +static VALUE ruby_whisper_params_get_max_text_tokens(VALUE self) { + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper_params, rwp); + return INT2NUM(rwp->params->n_max_text_ctx); +} +static VALUE ruby_whisper_params_set_max_text_tokens(VALUE self, VALUE value) { + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper_params, rwp); + rwp->params->n_max_text_ctx = NUM2INT(value); return value; } @@ -89,4 +206,31 @@ void Init_whisper() { rb_define_alloc_func(cParams, ruby_whisper_params_allocate); rb_define_method(cParams, "auto_detection=", ruby_whisper_params_set_auto_detection, 1); + rb_define_method(cParams, "auto_detection", ruby_whisper_params_get_auto_detection, 0); + rb_define_method(cParams, "translate=", ruby_whisper_params_set_translate, 1); + rb_define_method(cParams, "translate", ruby_whisper_params_get_translate, 0); + rb_define_method(cParams, "no_context=", ruby_whisper_params_set_no_context, 1); + rb_define_method(cParams, "no_context", ruby_whisper_params_get_no_context, 0); + rb_define_method(cParams, "single_segment=", ruby_whisper_params_set_single_segment, 1); + rb_define_method(cParams, "single_segment", ruby_whisper_params_get_single_segment, 0); + rb_define_method(cParams, "print_special", ruby_whisper_params_get_print_special, 0); + rb_define_method(cParams, "print_special=", 
ruby_whisper_params_set_print_special, 1); + rb_define_method(cParams, "print_progress", ruby_whisper_params_get_print_progress, 0); + rb_define_method(cParams, "print_progress=", ruby_whisper_params_set_print_progress, 1); + rb_define_method(cParams, "print_realtime", ruby_whisper_params_get_print_realtime, 0); + rb_define_method(cParams, "print_realtime=", ruby_whisper_params_set_print_realtime, 1); + rb_define_method(cParams, "print_timestamps", ruby_whisper_params_get_print_timestamps, 0); + rb_define_method(cParams, "print_timestamps=", ruby_whisper_params_set_print_timestamps, 1); + rb_define_method(cParams, "suppress_blank", ruby_whisper_params_get_suppress_blank, 0); + rb_define_method(cParams, "suppress_blank=", ruby_whisper_params_set_suppress_blank, 1); + rb_define_method(cParams, "suppress_non_speech_tokens", ruby_whisper_params_get_suppress_non_speech_tokens, 0); + rb_define_method(cParams, "suppress_non_speech_tokens=", ruby_whisper_params_set_suppress_non_speech_tokens, 1); + + rb_define_method(cParams, "offset", ruby_whisper_params_get_offset, 0); + rb_define_method(cParams, "offset=", ruby_whisper_params_set_offset, 1); + rb_define_method(cParams, "duration", ruby_whisper_params_get_duration, 0); + rb_define_method(cParams, "duration=", ruby_whisper_params_set_duration, 1); + + rb_define_method(cParams, "max_text_tokens", ruby_whisper_params_get_max_text_tokens, 0); + rb_define_method(cParams, "max_text_tokens=", ruby_whisper_params_set_max_text_tokens, 1); } diff --git a/bindings/ruby/tests/test_whisper.rb b/bindings/ruby/tests/test_whisper.rb index 82c21fe71f3..617a64eabd2 100644 --- a/bindings/ruby/tests/test_whisper.rb +++ b/bindings/ruby/tests/test_whisper.rb @@ -10,14 +10,101 @@ class TestWhisper < Test::Unit::TestCase def setup @params = Whisper::Params.new - @whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'for-tests-ggml-base.en.bin')) end def test_autodetect @params.auto_detection = true + assert @params.auto_detection + @params.auto_detection = false + assert !@params.auto_detection + end + + def test_offset + @params.offset = 10_000 + assert_equal @params.offset, 10_000 + @params.offset = 0 + assert_equal @params.offset, 0 + end + + def test_duration + @params.duration = 60_000 + assert_equal @params.duration, 60_000 + @params.duration = 0 + assert_equal @params.duration, 0 + end + + def test_max_text_tokens + @params.max_text_tokens = 300 + assert_equal @params.max_text_tokens, 300 + @params.max_text_tokens = 0 + assert_equal @params.max_text_tokens, 0 + end + + def test_translate + @params.translate = true + assert @params.translate + @params.translate = false + assert !@params.translate + end + + def test_no_context + @params.no_context = true + assert @params.no_context + @params.no_context = false + assert !@params.no_context + end + + def test_single_segment + @params.single_segment = true + assert @params.single_segment + @params.single_segment = false + assert !@params.single_segment + end + + def test_print_special + @params.print_special = true + assert @params.print_special + @params.print_special = false + assert !@params.print_special + end + + def test_print_progress + @params.print_progress = true + assert @params.print_progress + @params.print_progress = false + assert !@params.print_progress + end + + def test_print_realtime + @params.print_realtime = true + assert @params.print_realtime + @params.print_realtime = false + assert !@params.print_realtime + end + + def test_print_timestamps + @params.print_timestamps = true 
+ assert @params.print_timestamps + @params.print_timestamps = false + assert !@params.print_timestamps + end + + def test_suppress_blank + @params.suppress_blank = true + assert @params.suppress_blank + @params.suppress_blank = false + assert !@params.suppress_blank + end + + def test_suppress_non_speech_tokens + @params.suppress_non_speech_tokens = true + assert @params.suppress_non_speech_tokens + @params.suppress_non_speech_tokens = false + assert !@params.suppress_non_speech_tokens end def test_whisper + @whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'for-tests-ggml-base.en.bin')) end end From 621370290c2412c96de7c9f4ca945fa06f948d71 Mon Sep 17 00:00:00 2001 From: Todd Fisher Date: Tue, 14 Feb 2023 10:36:16 -0500 Subject: [PATCH 5/8] initial transcribe for ruby --- bindings/ruby/ext/.gitignore | 1 + bindings/ruby/ext/extconf.rb | 1 + .../ext/{ruby_whisper.c => ruby_whisper.cpp} | 238 ++++++++++++++++-- bindings/ruby/ext/ruby_whisper.h | 3 +- bindings/ruby/tests/test_whisper.rb | 38 ++- 5 files changed, 251 insertions(+), 30 deletions(-) rename bindings/ruby/ext/{ruby_whisper.c => ruby_whisper.cpp} (52%) diff --git a/bindings/ruby/ext/.gitignore b/bindings/ruby/ext/.gitignore index 1031bdd6634..7c9cb0378ae 100644 --- a/bindings/ruby/ext/.gitignore +++ b/bindings/ruby/ext/.gitignore @@ -4,3 +4,4 @@ ggml.h whisper.bundle whisper.cpp whisper.h +dr_wav.h diff --git a/bindings/ruby/ext/extconf.rb b/bindings/ruby/ext/extconf.rb index 0421cb2bb98..851c52dbd0c 100644 --- a/bindings/ruby/ext/extconf.rb +++ b/bindings/ruby/ext/extconf.rb @@ -3,6 +3,7 @@ system("cp #{File.join(File.dirname(__FILE__),'..','..','..','whisper.h')} .") system("cp #{File.join(File.dirname(__FILE__),'..','..','..','ggml.h')} .") system("cp #{File.join(File.dirname(__FILE__),'..','..','..','ggml.c')} .") +system("cp #{File.join(File.dirname(__FILE__),'..','..','..','examples','dr_wav.h')} .") # need to use c++ compiler flags diff --git a/bindings/ruby/ext/ruby_whisper.c b/bindings/ruby/ext/ruby_whisper.cpp similarity index 52% rename from bindings/ruby/ext/ruby_whisper.c rename to bindings/ruby/ext/ruby_whisper.cpp index 33324a0c08e..e7416ba26c4 100644 --- a/bindings/ruby/ext/ruby_whisper.c +++ b/bindings/ruby/ext/ruby_whisper.cpp @@ -1,20 +1,32 @@ #include #include "ruby_whisper.h" +#define DR_WAV_IMPLEMENTATION +#include "dr_wav.h" +#include +#include +#include +#include +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif #define BOOL_PARAMS_SETTER(self, prop, value) \ ruby_whisper_params *rwp; \ Data_Get_Struct(self, ruby_whisper_params, rwp); \ if (value == Qfalse || value == Qnil) { \ - rwp->params->prop = false; \ + rwp->params.prop = false; \ } else { \ - rwp->params->prop = true; \ + rwp->params.prop = true; \ } \ return value; \ #define BOOL_PARAMS_GETTER(self, prop) \ ruby_whisper_params *rwp; \ Data_Get_Struct(self, ruby_whisper_params, rwp); \ - if (rwp->params->prop) { \ + if (rwp->params.prop) { \ return Qtrue; \ } else { \ return Qfalse; \ @@ -31,10 +43,6 @@ static void ruby_whisper_free(ruby_whisper *rw) { } } static void ruby_whisper_params_free(ruby_whisper_params *rwp) { - if (rwp->params) { - free(rwp->params); - rwp->params = NULL; - } } void rb_whisper_mark(ruby_whisper *rw) { @@ -64,7 +72,7 @@ static VALUE ruby_whisper_allocate(VALUE klass) { static VALUE ruby_whisper_params_allocate(VALUE klass) { ruby_whisper_params *rwp; rwp = ALLOC(ruby_whisper_params); - rwp->params = ALLOC(struct whisper_full_params); + rwp->params = 
whisper_full_default_params(WHISPER_SAMPLING_GREEDY); return Data_Wrap_Struct(klass, rb_whisper_params_mark, rb_whisper_params_free, rwp); } @@ -80,29 +88,161 @@ static VALUE ruby_whisper_initialize(int argc, VALUE *argv, VALUE self) { rb_raise(rb_eRuntimeError, "Expected file path to model to initialize Whisper::Context"); } rw->context = whisper_init_from_file(StringValueCStr(whisper_model_file_path)); + if (rw->context == nullptr) { + rb_raise(rb_eRuntimeError, "error: failed to initialize whisper context"); + } return self; } /* - * params.auto_detection = true|false + * transcribe a single file + * can emit to a block results + * + **/ +static VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) { + ruby_whisper *rw; + ruby_whisper_params *rwp; + VALUE wave_file_path, blk, params; + + rb_scan_args(argc, argv, "02&", &wave_file_path, ¶ms, &blk); + Data_Get_Struct(self, ruby_whisper, rw); + Data_Get_Struct(params, ruby_whisper_params, rwp); + + if (!rb_respond_to(wave_file_path, rb_intern("to_s"))) { + rb_raise(rb_eRuntimeError, "Expected file path to wave file"); + } + + std::string fname_inp = StringValueCStr(wave_file_path); + + std::vector pcmf32; // mono-channel F32 PCM + std::vector> pcmf32s; // stereo-channel F32 PCM + + // WAV input - this is directly from main.cpp example + { + drwav wav; + std::vector wav_data; // used for pipe input from stdin + + if (fname_inp == "-") { + { + uint8_t buf[1024]; + while (true) { + const size_t n = fread(buf, 1, sizeof(buf), stdin); + if (n == 0) { + break; + } + wav_data.insert(wav_data.end(), buf, buf + n); + } + } + + if (drwav_init_memory(&wav, wav_data.data(), wav_data.size(), nullptr) == false) { + fprintf(stderr, "error: failed to open WAV file from stdin\n"); + return self; + } + + fprintf(stderr, "%s: read %zu bytes from stdin\n", __func__, wav_data.size()); + } else if (drwav_init_file(&wav, fname_inp.c_str(), nullptr) == false) { + fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname_inp.c_str()); + return self; + } + + if (wav.channels != 1 && wav.channels != 2) { + fprintf(stderr, "WAV file '%s' must be mono or stereo\n", fname_inp.c_str()); + return self; + } + + if (rwp->diarize && wav.channels != 2 && rwp->params.print_timestamps == false) { + fprintf(stderr, "WAV file '%s' must be stereo for diarization and timestamps have to be enabled\n", fname_inp.c_str()); + return self; + } + + if (wav.sampleRate != WHISPER_SAMPLE_RATE) { + fprintf(stderr, "WAV file '%s' must be %i kHz\n", fname_inp.c_str(), WHISPER_SAMPLE_RATE/1000); + return self; + } + + if (wav.bitsPerSample != 16) { + fprintf(stderr, "WAV file '%s' must be 16-bit\n", fname_inp.c_str()); + return self; + } + + const uint64_t n = wav_data.empty() ? 
wav.totalPCMFrameCount : wav_data.size()/(wav.channels*wav.bitsPerSample/8); + + std::vector pcm16; + pcm16.resize(n*wav.channels); + drwav_read_pcm_frames_s16(&wav, n, pcm16.data()); + drwav_uninit(&wav); + + // convert to mono, float + pcmf32.resize(n); + if (wav.channels == 1) { + for (uint64_t i = 0; i < n; i++) { + pcmf32[i] = float(pcm16[i])/32768.0f; + } + } else { + for (uint64_t i = 0; i < n; i++) { + pcmf32[i] = float(pcm16[2*i] + pcm16[2*i + 1])/65536.0f; + } + } + + if (rwp->diarize) { + // convert to stereo, float + pcmf32s.resize(2); + + pcmf32s[0].resize(n); + pcmf32s[1].resize(n); + for (uint64_t i = 0; i < n; i++) { + pcmf32s[0][i] = float(pcm16[2*i])/32768.0f; + pcmf32s[1][i] = float(pcm16[2*i + 1])/32768.0f; + } + } + } + { + static bool is_aborted = false; // NOTE: this should be atomic to avoid data race + + rwp->params.encoder_begin_callback = [](struct whisper_context * /*ctx*/, void * user_data) { + bool is_aborted = *(bool*)user_data; + return !is_aborted; + }; + rwp->params.encoder_begin_callback_user_data = &is_aborted; + } + + if (whisper_full_parallel(rw->context, rwp->params, pcmf32.data(), pcmf32.size(), 1) != 0) { + fprintf(stderr, "failed to process audio\n"); + return self; + } + const int n_segments = whisper_full_n_segments(rw->context); + VALUE output = rb_str_new2(""); + for (int i = 0; i < n_segments; ++i) { + const char * text = whisper_full_get_segment_text(rw->context, i); + output = rb_str_concat(output, rb_str_new2(text)); + } + VALUE idCall = rb_intern("call"); + if (blk != Qnil) { + rb_funcall(blk, idCall, 1, output); + } + return self; +} + +/* + * params.language = "auto" | "en", etc... */ -static VALUE ruby_whisper_params_set_auto_detection(VALUE self, VALUE value) { +static VALUE ruby_whisper_params_set_language(VALUE self, VALUE value) { ruby_whisper_params *rwp; Data_Get_Struct(self, ruby_whisper_params, rwp); if (value == Qfalse || value == Qnil) { - rwp->params->language = NULL; + rwp->params.language = "auto"; } else { - rwp->params->language = "auto"; + rwp->params.language = StringValueCStr(value); } return value; } -static VALUE ruby_whisper_params_get_auto_detection(VALUE self) { +static VALUE ruby_whisper_params_get_language(VALUE self) { ruby_whisper_params *rwp; Data_Get_Struct(self, ruby_whisper_params, rwp); - if (rwp->params->language) { - return Qtrue; + if (rwp->params.language) { + return rb_str_new2(rwp->params.language); } else { - return Qfalse; + return rb_str_new2("auto"); } } static VALUE ruby_whisper_params_set_translate(VALUE self, VALUE value) { @@ -159,39 +299,76 @@ static VALUE ruby_whisper_params_set_suppress_non_speech_tokens(VALUE self, VALU static VALUE ruby_whisper_params_get_suppress_non_speech_tokens(VALUE self) { BOOL_PARAMS_GETTER(self, suppress_non_speech_tokens) } +static VALUE ruby_whisper_params_get_token_timestamps(VALUE self) { + BOOL_PARAMS_GETTER(self, token_timestamps) +} +static VALUE ruby_whisper_params_set_token_timestamps(VALUE self, VALUE value) { + BOOL_PARAMS_SETTER(self, token_timestamps, value) +} +static VALUE ruby_whisper_params_get_split_on_word(VALUE self) { + BOOL_PARAMS_GETTER(self, split_on_word) +} +static VALUE ruby_whisper_params_set_split_on_word(VALUE self, VALUE value) { + BOOL_PARAMS_SETTER(self, split_on_word, value) +} +static VALUE ruby_whisper_params_get_speed_up(VALUE self) { + BOOL_PARAMS_GETTER(self, speed_up) +} +static VALUE ruby_whisper_params_set_speed_up(VALUE self, VALUE value) { + BOOL_PARAMS_SETTER(self, speed_up, value) +} +static VALUE 
ruby_whisper_params_get_diarize(VALUE self) { + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper_params, rwp); + if (rwp->diarize) { + return Qtrue; + } else { + return Qfalse; + } +} +static VALUE ruby_whisper_params_set_diarize(VALUE self, VALUE value) { + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper_params, rwp); + if (value == Qfalse || value == Qnil) { + rwp->diarize = false; + } else { + rwp->diarize = true; + } \ + return value; +} static VALUE ruby_whisper_params_get_offset(VALUE self) { ruby_whisper_params *rwp; Data_Get_Struct(self, ruby_whisper_params, rwp); - return INT2NUM(rwp->params->offset_ms); + return INT2NUM(rwp->params.offset_ms); } static VALUE ruby_whisper_params_set_offset(VALUE self, VALUE value) { ruby_whisper_params *rwp; Data_Get_Struct(self, ruby_whisper_params, rwp); - rwp->params->offset_ms = NUM2INT(value); + rwp->params.offset_ms = NUM2INT(value); return value; } static VALUE ruby_whisper_params_get_duration(VALUE self) { ruby_whisper_params *rwp; Data_Get_Struct(self, ruby_whisper_params, rwp); - return INT2NUM(rwp->params->duration_ms); + return INT2NUM(rwp->params.duration_ms); } static VALUE ruby_whisper_params_set_duration(VALUE self, VALUE value) { ruby_whisper_params *rwp; Data_Get_Struct(self, ruby_whisper_params, rwp); - rwp->params->duration_ms = NUM2INT(value); + rwp->params.duration_ms = NUM2INT(value); return value; } static VALUE ruby_whisper_params_get_max_text_tokens(VALUE self) { ruby_whisper_params *rwp; Data_Get_Struct(self, ruby_whisper_params, rwp); - return INT2NUM(rwp->params->n_max_text_ctx); + return INT2NUM(rwp->params.n_max_text_ctx); } static VALUE ruby_whisper_params_set_max_text_tokens(VALUE self, VALUE value) { ruby_whisper_params *rwp; Data_Get_Struct(self, ruby_whisper_params, rwp); - rwp->params->n_max_text_ctx = NUM2INT(value); + rwp->params.n_max_text_ctx = NUM2INT(value); return value; } @@ -203,10 +380,12 @@ void Init_whisper() { rb_define_alloc_func(cContext, ruby_whisper_allocate); rb_define_method(cContext, "initialize", ruby_whisper_initialize, -1); + rb_define_method(cContext, "transcribe", ruby_whisper_transcribe, -1); + rb_define_alloc_func(cParams, ruby_whisper_params_allocate); - rb_define_method(cParams, "auto_detection=", ruby_whisper_params_set_auto_detection, 1); - rb_define_method(cParams, "auto_detection", ruby_whisper_params_get_auto_detection, 0); + rb_define_method(cParams, "language=", ruby_whisper_params_set_language, 1); + rb_define_method(cParams, "language", ruby_whisper_params_get_language, 0); rb_define_method(cParams, "translate=", ruby_whisper_params_set_translate, 1); rb_define_method(cParams, "translate", ruby_whisper_params_get_translate, 0); rb_define_method(cParams, "no_context=", ruby_whisper_params_set_no_context, 1); @@ -225,6 +404,14 @@ void Init_whisper() { rb_define_method(cParams, "suppress_blank=", ruby_whisper_params_set_suppress_blank, 1); rb_define_method(cParams, "suppress_non_speech_tokens", ruby_whisper_params_get_suppress_non_speech_tokens, 0); rb_define_method(cParams, "suppress_non_speech_tokens=", ruby_whisper_params_set_suppress_non_speech_tokens, 1); + rb_define_method(cParams, "token_timestamps", ruby_whisper_params_get_token_timestamps, 0); + rb_define_method(cParams, "token_timestamps=", ruby_whisper_params_set_token_timestamps, 1); + rb_define_method(cParams, "split_on_word", ruby_whisper_params_get_split_on_word, 0); + rb_define_method(cParams, "split_on_word=", ruby_whisper_params_set_split_on_word, 1); + rb_define_method(cParams, 
"speed_up", ruby_whisper_params_get_speed_up, 0); + rb_define_method(cParams, "speed_up=", ruby_whisper_params_set_speed_up, 1); + rb_define_method(cParams, "diarize", ruby_whisper_params_get_diarize, 0); + rb_define_method(cParams, "diarize=", ruby_whisper_params_set_diarize, 1); rb_define_method(cParams, "offset", ruby_whisper_params_get_offset, 0); rb_define_method(cParams, "offset=", ruby_whisper_params_set_offset, 1); @@ -234,3 +421,6 @@ void Init_whisper() { rb_define_method(cParams, "max_text_tokens", ruby_whisper_params_get_max_text_tokens, 0); rb_define_method(cParams, "max_text_tokens=", ruby_whisper_params_set_max_text_tokens, 1); } +#ifdef __cplusplus +} +#endif diff --git a/bindings/ruby/ext/ruby_whisper.h b/bindings/ruby/ext/ruby_whisper.h index 246d13334a1..8c35b7cb65c 100644 --- a/bindings/ruby/ext/ruby_whisper.h +++ b/bindings/ruby/ext/ruby_whisper.h @@ -8,7 +8,8 @@ typedef struct { } ruby_whisper; typedef struct { - struct whisper_full_params *params; + struct whisper_full_params params; + bool diarize; } ruby_whisper_params; #endif diff --git a/bindings/ruby/tests/test_whisper.rb b/bindings/ruby/tests/test_whisper.rb index 617a64eabd2..03d5a233bfd 100644 --- a/bindings/ruby/tests/test_whisper.rb +++ b/bindings/ruby/tests/test_whisper.rb @@ -12,11 +12,11 @@ def setup @params = Whisper::Params.new end - def test_autodetect - @params.auto_detection = true - assert @params.auto_detection - @params.auto_detection = false - assert !@params.auto_detection + def test_language + @params.language = "en" + assert_equal @params.language, "en" + @params.language = "auto" + assert_equal @params.language, "auto" end def test_offset @@ -103,8 +103,36 @@ def test_suppress_non_speech_tokens assert !@params.suppress_non_speech_tokens end + def test_token_timestamps + @params.token_timestamps = true + assert @params.token_timestamps + @params.token_timestamps = false + assert !@params.token_timestamps + end + + def test_split_on_word + @params.split_on_word = true + assert @params.split_on_word + @params.split_on_word = false + assert !@params.split_on_word + end + + def test_speed_up + @params.speed_up = true + assert @params.speed_up + @params.speed_up = false + assert !@params.speed_up + end + def test_whisper @whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'for-tests-ggml-base.en.bin')) + params = Whisper::Params.new + params.print_timestamps = false + + jfk = File.join(TOPDIR, '..', '..', 'samples', 'jfk.wav') + @whisper.transcribe(jfk, params) {|text| + puts text.inspect + } end end From 673ad053e8ba01ab053b3a83382038fe3e51e7ec Mon Sep 17 00:00:00 2001 From: Todd Fisher Date: Tue, 14 Feb 2023 10:38:23 -0500 Subject: [PATCH 6/8] use en model and transcribe jfk with assertion --- bindings/ruby/tests/test_whisper.rb | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bindings/ruby/tests/test_whisper.rb b/bindings/ruby/tests/test_whisper.rb index 03d5a233bfd..fa6a3e2d4e8 100644 --- a/bindings/ruby/tests/test_whisper.rb +++ b/bindings/ruby/tests/test_whisper.rb @@ -125,13 +125,13 @@ def test_speed_up end def test_whisper - @whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'for-tests-ggml-base.en.bin')) + @whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'ggml-base.en.bin')) params = Whisper::Params.new params.print_timestamps = false jfk = File.join(TOPDIR, '..', '..', 'samples', 'jfk.wav') @whisper.transcribe(jfk, params) {|text| - puts text.inspect + assert_match /ask not what your country can do for 
you, ask what you can do for your country/, text } end From d1a2145cfc35b6cf01d287959539e57cd884c79e Mon Sep 17 00:00:00 2001 From: Todd Fisher Date: Wed, 15 Feb 2023 09:01:28 -0500 Subject: [PATCH 7/8] possibly this works for building ruby binding --- .github/workflows/build.yml | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 1a2baac8064..8ebd6361aa8 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -14,11 +14,13 @@ jobs: sudo apt-get update sudo apt-get install build-essential sudo apt-get install libsdl2-dev + sudo apt-get install ruby-dev - name: Build run: | make make stream + cd bindings/ruby/ext && ruby extconf.rb && make macOS-latest: runs-on: macOS-latest @@ -30,12 +32,13 @@ jobs: - name: Dependencies run: | brew update - brew install sdl2 + brew install sdl2 ruby - name: Build run: | make make stream + cd bindings/ruby/ext && ruby extconf.rb && make ubuntu-latest-gcc: runs-on: ubuntu-latest From 27b1a919373a8b32c43277c192737ec8929cbff8 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 15 Feb 2023 19:15:46 +0200 Subject: [PATCH 8/8] ci : try to add ruby workflow --- .../{bindings.yml => bindings-go.yml} | 2 +- .github/workflows/bindings-ruby.yml | 22 +++++++++++++++++++ .github/workflows/build.yml | 5 +---- 3 files changed, 24 insertions(+), 5 deletions(-) rename .github/workflows/{bindings.yml => bindings-go.yml} (93%) create mode 100644 .github/workflows/bindings-ruby.yml diff --git a/.github/workflows/bindings.yml b/.github/workflows/bindings-go.yml similarity index 93% rename from .github/workflows/bindings.yml rename to .github/workflows/bindings-go.yml index 02667edb59f..13f1950a5ba 100644 --- a/.github/workflows/bindings.yml +++ b/.github/workflows/bindings-go.yml @@ -1,4 +1,4 @@ -name: Bindings Tests +name: Bindings Tests (Go) on: push: paths: diff --git a/.github/workflows/bindings-ruby.yml b/.github/workflows/bindings-ruby.yml new file mode 100644 index 00000000000..902dfe6a2b5 --- /dev/null +++ b/.github/workflows/bindings-ruby.yml @@ -0,0 +1,22 @@ +name: Bindings Tests (Ruby) +on: + push: + paths: + - bindings/ruby/** + - whisper.h + pull_request: + paths: + - bindings/ruby/** + - whisper.h + +jobs: + ubuntu-latest: + runs-on: ubuntu-latest + steps: + - uses: ruby/setup-ruby@v1 + with: + ruby-version: '3.0' + - uses: actions/checkout@v1 + - run: | + cd bindings/ruby/ext + ruby extconf.rb && make diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 8ebd6361aa8..1a2baac8064 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -14,13 +14,11 @@ jobs: sudo apt-get update sudo apt-get install build-essential sudo apt-get install libsdl2-dev - sudo apt-get install ruby-dev - name: Build run: | make make stream - cd bindings/ruby/ext && ruby extconf.rb && make macOS-latest: runs-on: macOS-latest @@ -32,13 +30,12 @@ jobs: - name: Dependencies run: | brew update - brew install sdl2 ruby + brew install sdl2 - name: Build run: | make make stream - cd bindings/ruby/ext && ruby extconf.rb && make ubuntu-latest-gcc: runs-on: ubuntu-latest
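
A minimal usage sketch of the Ruby API added by this series, assuming the extension has been built via extconf.rb and is loadable as 'whisper' (per create_makefile('whisper')); the model and sample paths below mirror the ones used in tests/test_whisper.rb and are placeholders for your own checkout:

    require 'whisper'

    # Context takes a path to a ggml model file; Params mirrors whisper_full_params.
    ctx    = Whisper::Context.new('models/ggml-base.en.bin')
    params = Whisper::Params.new
    params.language         = 'en'    # assigning nil/false falls back to "auto" detection
    params.translate        = false
    params.print_timestamps = false

    # transcribe(path, params) decodes the WAV file and yields the resulting text to the block.
    ctx.transcribe('samples/jfk.wav', params) do |text|
      puts text
    end

The transcribe call wraps whisper_full_parallel and yields the concatenated text of all segments to the block, so the block receives the full transcript once rather than once per segment.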