
Merge commit '0870e6bfb0d46d58e29c6d5822e8bf629d03bdcc' as 'tinychat'

Alex Cheema 10 months ago
parent
commit
e1cb840b31
100 files changed, with 12097 additions and 0 deletions
  1. tinychat/.github/workflows/benchmark.yml (+435, -0)
  2. tinychat/.github/workflows/docs.yml (+30, -0)
  3. tinychat/.github/workflows/python-publish.yml (+30, -0)
  4. tinychat/.github/workflows/szdiff.yml (+98, -0)
  5. tinychat/.github/workflows/test.yml (+533, -0)
  6. tinychat/.gitignore (+56, -0)
  7. tinychat/.pre-commit-config.yaml (+51, -0)
  8. tinychat/.pylintrc (+472, -0)
  9. tinychat/LICENSE (+7, -0)
  10. tinychat/README.md (+178, -0)
  11. tinychat/autogen_stubs.sh (+289, -0)
  12. tinychat/docs/CNAME (+1, -0)
  13. tinychat/docs/abstractions2.py (+118, -0)
  14. tinychat/docs/abstractions3.py (+62, -0)
  15. tinychat/docs/developer.md (+56, -0)
  16. tinychat/docs/dtypes.md (+9, -0)
  17. tinychat/docs/env_vars.md (+52, -0)
  18. tinychat/docs/favicon.svg (+25, -0)
  19. tinychat/docs/function.md (+33, -0)
  20. tinychat/docs/index.md (+47, -0)
  21. tinychat/docs/logo_tiny_dark.svg (+11, -0)
  22. tinychat/docs/logo_tiny_light.svg (+11, -0)
  23. tinychat/docs/mnist.md (+177, -0)
  24. tinychat/docs/nn.md (+31, -0)
  25. tinychat/docs/quickstart.md (+308, -0)
  26. tinychat/docs/runtime/hcq.md (+146, -0)
  27. tinychat/docs/runtime/overview.md (+51, -0)
  28. tinychat/docs/showcase.md (+62, -0)
  29. tinychat/docs/showcase/mnist_by_tinygrad.jpg (binary)
  30. tinychat/docs/showcase/stable_diffusion_by_tinygrad.jpg (binary)
  31. tinychat/docs/showcase/yolo_by_tinygrad.jpg (binary)
  32. tinychat/docs/showcase/yolov8_showcase_image.png (binary)
  33. tinychat/docs/tensor/creation.md (+24, -0)
  34. tinychat/docs/tensor/index.md (+36, -0)
  35. tinychat/docs/tensor/movement.md (+26, -0)
  36. tinychat/docs/tensor/ops.md (+113, -0)
  37. tinychat/docs/tinybox.md (+55, -0)
  38. tinychat/docs/tinygrad_intro.pdf (binary)
  39. tinychat/examples/__init__.py (+0, -0)
  40. tinychat/examples/beautiful_cartpole.py (+129, -0)
  41. tinychat/examples/beautiful_mnist.py (+49, -0)
  42. tinychat/examples/beautiful_mnist_multigpu.py (+56, -0)
  43. tinychat/examples/coder.py (+95, -0)
  44. tinychat/examples/compile_efficientnet.py (+69, -0)
  45. tinychat/examples/compile_tensorflow.py (+100, -0)
  46. tinychat/examples/conversation.py (+343, -0)
  47. tinychat/examples/conversation_data/pre_prompt_gary.yaml (+13, -0)
  48. tinychat/examples/conversation_data/pre_prompt_george.yaml (+20, -0)
  49. tinychat/examples/conversation_data/pre_prompt_lexie.yaml (+16, -0)
  50. tinychat/examples/conversation_data/pre_prompt_stacy.yaml (+15, -0)
  51. tinychat/examples/efficientnet.py (+89, -0)
  52. tinychat/examples/gpt2.py (+215, -0)
  53. tinychat/examples/handcode_opt.py (+126, -0)
  54. tinychat/examples/hlb_cifar10.py (+431, -0)
  55. tinychat/examples/index.html (+124, -0)
  56. tinychat/examples/llama.py (+510, -0)
  57. tinychat/examples/llama3.py (+446, -0)
  58. tinychat/examples/llm.c/.gitignore (+3, -0)
  59. tinychat/examples/llm.c/export.py (+106, -0)
  60. tinychat/examples/llm.c/train_gpt2.py (+192, -0)
  61. tinychat/examples/llm.c/ubench/matmul.c (+65, -0)
  62. tinychat/examples/mamba.py (+316, -0)
  63. tinychat/examples/mask_rcnn.py (+299, -0)
  64. tinychat/examples/mixtral.py (+59, -0)
  65. tinychat/examples/mlperf/README (+19, -0)
  66. tinychat/examples/mlperf/dataloader.py (+378, -0)
  67. tinychat/examples/mlperf/helpers.py (+240, -0)
  68. tinychat/examples/mlperf/initializers.py (+68, -0)
  69. tinychat/examples/mlperf/losses.py (+6, -0)
  70. tinychat/examples/mlperf/lr_schedulers.py (+22, -0)
  71. tinychat/examples/mlperf/metrics.py (+61, -0)
  72. tinychat/examples/mlperf/model_eval.py (+252, -0)
  73. tinychat/examples/mlperf/model_spec.py (+70, -0)
  74. tinychat/examples/mlperf/model_train.py (+691, -0)
  75. tinychat/examples/mlperf/training_submission_v4.0/tinycorp/benchmarks/resnet/implementations/tinybox_green/README.md (+50, -0)
  76. tinychat/examples/mlperf/training_submission_v4.0/tinycorp/benchmarks/resnet/implementations/tinybox_green/dev_beam.sh (+13, -0)
  77. tinychat/examples/mlperf/training_submission_v4.0/tinycorp/benchmarks/resnet/implementations/tinybox_green/dev_run.sh (+15, -0)
  78. tinychat/examples/mlperf/training_submission_v4.0/tinycorp/benchmarks/resnet/implementations/tinybox_green/run_and_time.sh (+23, -0)
  79. tinychat/examples/mlperf/training_submission_v4.0/tinycorp/benchmarks/resnet/implementations/tinybox_red/README.md (+50, -0)
  80. tinychat/examples/mlperf/training_submission_v4.0/tinycorp/benchmarks/resnet/implementations/tinybox_red/dev_beam.sh (+13, -0)
  81. tinychat/examples/mlperf/training_submission_v4.0/tinycorp/benchmarks/resnet/implementations/tinybox_red/dev_run.sh (+15, -0)
  82. tinychat/examples/mlperf/training_submission_v4.0/tinycorp/benchmarks/resnet/implementations/tinybox_red/run_and_time.sh (+23, -0)
  83. tinychat/examples/mlperf/training_submission_v4.0/tinycorp/benchmarks/resnet/implementations/tinybox_red/setup.sh (+8, -0)
  84. tinychat/examples/mlperf/training_submission_v4.0/tinycorp/systems/tinybox_green.json (+38, -0)
  85. tinychat/examples/mlperf/training_submission_v4.0/tinycorp/systems/tinybox_red.json (+38, -0)
  86. tinychat/examples/mnist_gan.py (+107, -0)
  87. tinychat/examples/openelm.py (+118, -0)
  88. tinychat/examples/openpilot/compile2.py (+204, -0)
  89. tinychat/examples/openpilot/go.sh (+2, -0)
  90. tinychat/examples/other_mnist/beautiful_mnist_mlx.py (+55, -0)
  91. tinychat/examples/other_mnist/beautiful_mnist_torch.py (+55, -0)
  92. tinychat/examples/rl/lightupbutton.py (+45, -0)
  93. tinychat/examples/sdv2.py (+147, -0)
  94. tinychat/examples/sdxl.py (+428, -0)
  95. tinychat/examples/sdxl_seed0.png (binary)
  96. tinychat/examples/serious_mnist.py (+136, -0)
  97. tinychat/examples/simple_conv_bn.py (+17, -0)
  98. tinychat/examples/so_vits_svc.py (+673, -0)
  99. tinychat/examples/sovits_helpers/preprocess.py (+204, -0)
  100. tinychat/examples/stable_diffusion.py (+294, -0)

+ 435 - 0
tinychat/.github/workflows/benchmark.yml

@@ -0,0 +1,435 @@
+name: Benchmarks
+env:
+  RUN_PROCESS_REPLAY: "1"
+  ASSERT_PROCESS_REPLAY: "0"
+  PYTHONPATH: .
+
+on:
+  push:
+    branches:
+      - master
+      - update_benchmark
+      - update_benchmark_staging
+  workflow_dispatch:
+    inputs:
+      run_process_replay:
+        description: "Run process replay tests"
+        required: false
+        default: false
+        type: boolean
+
+jobs:
+  testmacbenchmark:
+    name: Mac Benchmark
+    runs-on: [self-hosted, macOS]
+    defaults:
+      run:
+        shell: bash -o pipefail {0}
+    if: github.repository_owner == 'tinygrad'
+    steps:
+    - name: Checkout Code
+      uses: actions/checkout@v4
+    - name: Symlink models and datasets
+      run: |
+        mkdir -p weights
+        ln -s ~/tinygrad/extra/disassemblers/applegpu extra/disassemblers/applegpu
+        ln -s ~/tinygrad/weights/sd-v1-4.ckpt weights/sd-v1-4.ckpt
+        ln -s ~/tinygrad/weights/bpe_simple_vocab_16e6.txt.gz weights/bpe_simple_vocab_16e6.txt.gz
+        ln -s ~/tinygrad/weights/LLaMA weights/LLaMA
+        ln -s ~/tinygrad/extra/datasets/cifar-10-python.tar.gz extra/datasets/cifar-10-python.tar.gz
+    - name: Run Stable Diffusion
+      run: JIT=2 python3 examples/stable_diffusion.py --seed 0 --noshow --timing | tee sd.txt
+    - name: Run Stable Diffusion with fp16
+      run: JIT=2 python3 examples/stable_diffusion.py --fp16 --seed 0 --noshow --timing | tee sd_fp16.txt
+    - name: Run SDXL
+      run: JIT=2 python3 examples/sdxl.py --seed 0 --noshow --timing | tee sdxl.txt
+    - name: Run model inference benchmark
+      run: METAL=1 python3 test/external/external_model_benchmark.py
+    - name: Test speed vs torch
+      run: BIG=2 MPS=1 python3 test/test_speed_v_torch.py | tee torch_speed.txt
+    - name: Test tensor cores
+      run: METAL=1 python3 test/test_linearizer.py TestLinearizer.test_tensor_cores TestLinearizer.test_tensor_cores_padded
+    - name: Run Tensor Core GEMM
+      run: |
+        DEBUG=2 python3 extra/gemm/simple_matmul.py | tee matmul.txt
+        DEBUG=2 HALF=1 python3 extra/gemm/simple_matmul.py | tee matmul_half.txt
+    - name: Fuzz Padded Tensor Core GEMM
+      run: METAL=1 M_START=6 M_STOP=10 M_STEP=1 N_START=6 N_STOP=10 N_STEP=1 K_START=6 K_STOP=24 K_STEP=1 TC_OPT=2 DEBUG=2 python3 ./extra/gemm/fuzz_matmul.py
+    - name: Run LLaMA
+      run: |
+        JIT=0 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_unjitted.txt
+        JIT=1 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_jitted.txt
+    - name: Run LLaMA with BEAM
+      run: JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_beam.txt
+    - name: Run quantized LLaMA
+      run: |
+        python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing --quantize int8 | tee llama_int8.txt
+        python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing --quantize nf4 | tee llama_nf4.txt
+    - name: Run LLaMA 7B on 4 (virtual) GPUs
+      run: python3 examples/llama.py --gen 1 --size 7B --shard 4 --prompt "Hello." --count 10 --temperature 0  --timing | tee llama_four_gpu.txt
+    - name: Run GPT2
+      run: |
+        JIT=0 python3 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing | tee gpt2_unjitted.txt
+        JIT=1 python3 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing | tee gpt2_jitted.txt
+    - name: Run GPT2 w HALF
+      run: HALF=1 python3 examples/gpt2.py --count 10 --temperature 0 --timing | tee gpt2_half.txt
+    - name: Run GPT2 w HALF/BEAM
+      run: HALF=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 CAST_BEFORE_VIEW=0 python3 examples/gpt2.py --count 10 --temperature 0 --timing | tee gpt2_half_beam.txt
+    - name: Train MNIST
+      run: time PYTHONPATH=. TARGET_EVAL_ACC_PCT=97.3 python3 examples/beautiful_mnist.py | tee beautiful_mnist.txt
+    - name: Run 10 CIFAR training steps
+      run: JIT=2 STEPS=10 python3 examples/hlb_cifar10.py | tee train_cifar.txt
+    - name: Run 10 CIFAR training steps w HALF
+      run: JIT=2 STEPS=10 DEFAULT_FLOAT=HALF python3 examples/hlb_cifar10.py | tee train_cifar_half.txt
+    #- name: Run 10 CIFAR training steps w BF16
+    #  run: STEPS=10 DEFAULT_FLOAT=BFLOAT16 python3 examples/hlb_cifar10.py | tee train_cifar_bf16.txt
+    - name: Run 10 CIFAR training steps w winograd
+      run: JIT=2 WINO=1 STEPS=10 python3 examples/hlb_cifar10.py | tee train_cifar_wino.txt
+    - name: Run process replay tests
+      if: env.RUN_PROCESS_REPLAY == '1'
+      run: cp test/external/process_replay/process_replay.py ./process_replay.py && git fetch origin master && git -c advice.detachedHead=false checkout origin/master && PYTHONPATH=. python3 process_replay.py
+    - uses: actions/upload-artifact@v4
+      with:
+        name: Speed (Mac)
+        path: |
+          onnx_inference_speed.csv
+          torch_speed.txt
+          llama_unjitted.txt
+          llama_jitted.txt
+          llama_beam.txt
+          llama_int8.txt
+          llama_nf4.txt
+          llama_four_gpu.txt
+          gpt2_unjitted.txt
+          gpt2_jitted.txt
+          gpt2_half.txt
+          gpt2_half_beam.txt
+          matmul.txt
+          matmul_half.txt
+          sd.txt
+          sd_fp16.txt
+          sdxl.txt
+          beautiful_mnist.txt
+          train_cifar.txt
+          train_cifar_half.txt
+          train_cifar_bf16.txt
+          train_cifar_wino.txt
+
+  testnvidiabenchmark:
+    name: tinybox green Benchmark
+    runs-on: [self-hosted, Linux, tinyboxgreen]
+    defaults:
+      run:
+        shell: bash -o pipefail {0}
+    if: github.repository_owner == 'tinygrad'
+    steps:
+    - name: Checkout Code
+      uses: actions/checkout@v4
+    - name: Print nvidia-smi
+      run: nvidia-smi
+    - name: Symlink models and datasets
+      run: |
+        mkdir -p weights
+        ln -s ~/tinygrad/weights/LLaMA weights/LLaMA
+        ln -s /raid/weights/mixtral-8x7b-32kseqlen weights/mixtral-8x7b-32kseqlen
+        ln -s /raid/weights/LLaMA-2 weights/LLaMA-2
+        ln -s /raid/weights/LLaMA-3 weights/LLaMA-3
+        mkdir -p extra/datasets
+        ln -s /raid/datasets/imagenet extra/datasets/imagenet
+    - name: Run model inference benchmark
+      run: NV=1 NOCLANG=1 python3 test/external/external_model_benchmark.py
+    - name: Test speed vs torch
+      run: NV=1 BIG=2 TORCHCUDA=1 python3 test/test_speed_v_torch.py | tee torch_speed.txt
+    - name: Test tensor cores
+      run: |
+        NV=1 python3 test/test_linearizer.py TestLinearizer.test_tensor_cores TestLinearizer.test_tensor_cores_padded
+        PTX=1 NV=1 python3 test/test_linearizer.py TestLinearizer.test_tensor_cores TestLinearizer.test_tensor_cores_padded
+    - name: Run Tensor Core GEMM (CUDA)
+      run: |
+        CUDA=1 HALF=1 DEBUG=2 python3 extra/gemm/simple_matmul.py | tee matmul.txt
+        CUDA=1 BFLOAT16=1 DEBUG=2 python3 extra/gemm/simple_matmul.py | tee matmul_bfloat16.txt
+    - name: Run Tensor Core GEMM (PTX)
+      run: NV=1 PTX=1 HALF=1 DEBUG=2 python3 extra/gemm/simple_matmul.py | tee matmul_ptx.txt
+    - name: Run Tensor Core GEMM (NV)
+      run: NV=1 HALF=1 DEBUG=2 python3 extra/gemm/simple_matmul.py | tee matmul_nv.txt
+    - name: Run Tensor Core GEMM (NV) with BEAM
+      run: BEAM=4 NV=1 HALF=1 IGNORE_BEAM_CACHE=1 DEBUG=2 python3 extra/gemm/simple_matmul.py
+    - name: Fuzz Padded Tensor Core GEMM (NV)
+      run: NV=1 M_START=12 M_STOP=20 M_STEP=1 N_START=6 N_STOP=10 N_STEP=1 K_START=28 K_STOP=36 K_STEP=1 HALF=1 TC_OPT=2 python3 ./extra/gemm/fuzz_matmul.py
+    - name: Fuzz Padded Tensor Core GEMM (PTX)
+      run: NV=1 PTX=1 M_START=12 M_STOP=20 M_STEP=1 N_START=6 N_STOP=10 N_STEP=1 K_START=28 K_STOP=36 K_STEP=1 HALF=1 TC_OPT=2 python3 ./extra/gemm/fuzz_matmul.py
+    - name: Run Stable Diffusion
+      run: NV=1 python3 examples/stable_diffusion.py --seed 0 --noshow --timing | tee sd.txt
+    - name: Run SDXL
+      run: NV=1 python3 examples/sdxl.py --seed 0 --noshow --timing | tee sdxl.txt
+    - name: Run LLaMA
+      run: |
+        NV=1 JIT=0 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_unjitted.txt
+        NV=1 JIT=1 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_jitted.txt
+    - name: Run LLaMA with BEAM
+      run: NV=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_beam.txt
+    - name: Run LLaMA 7B on 4 GPUs
+      run: NV=1 python3 examples/llama.py --gen 1 --size 7B --shard 4 --prompt "Hello." --count 10 --temperature 0  --timing | tee llama_four_gpu.txt
+    - name: Run LLaMA 7B on 6 GPUs
+      run: NV=1 python3 examples/llama.py --gen 1 --size 7B --shard 6 --prompt "Hello." --count 10 --temperature 0  --timing | tee llama_six_gpu.txt
+    # TODO: this is flaky
+    # - name: Run LLaMA-3 8B BEAM
+    #   run: NV=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/llama3.py --model weights/LLaMA-3/8B-SF-DPO/ --benchmark | tee llama3_beam.txt
+    - name: Run LLaMA-3 8B on 4 GPUs
+      run: NV=1 python3 examples/llama3.py --shard 4 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark | tee llama3_four_gpu.txt
+    - name: Run LLaMA-3 8B on 6 GPUs
+      run: NV=1 python3 examples/llama3.py --shard 6 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark | tee llama3_six_gpu.txt
+    # - name: Run LLaMA-2 70B
+    #   run: CUDA=1 python3 examples/llama.py --gen 2 --size 70B --shard 6 --prompt "Hello." --count 10 --temperature 0  --timing | tee llama_2_70B.txt
+    - name: Run Mixtral 8x7B
+      run: time NV=1 python3 examples/mixtral.py --temperature 0 --count 10 --timing | tee mixtral.txt
+    - name: Run GPT2
+      run: |
+        NV=1 JIT=0 python3 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing | tee gpt2_unjitted.txt
+        NV=1 JIT=1 python3 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing | tee gpt2_jitted.txt
+    - name: Run GPT2 w HALF
+      run: NV=1 HALF=1 python3 examples/gpt2.py --count 10 --temperature 0 --timing | tee gpt2_half.txt
+    - name: Run GPT2 w HALF/BEAM
+      run: NV=1 HALF=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 CAST_BEFORE_VIEW=0 python3 examples/gpt2.py --count 10 --temperature 0 --timing | tee gpt2_half_beam.txt
+    - name: Run process replay tests
+      if: env.RUN_PROCESS_REPLAY == '1'
+      run: cp test/external/process_replay/process_replay.py ./process_replay.py && git fetch origin master && git -c advice.detachedHead=false checkout origin/master && PYTHONPATH=. python3 process_replay.py
+    - uses: actions/upload-artifact@v4
+      with:
+        name: Speed (NVIDIA)
+        path: |
+          onnx_inference_speed.csv
+          torch_speed.txt
+          matmul.txt
+          matmul_bfloat16.txt
+          matmul_ptx.txt
+          matmul_nv.txt
+          sd.txt
+          sdxl.txt
+          llama_unjitted.txt
+          llama_jitted.txt
+          llama_beam.txt
+          llama_four_gpu.txt
+          llama_six_gpu.txt
+          llama3_beam.txt
+          llama3_four_gpu.txt
+          llama3_six_gpu.txt
+          # llama_2_70B.txt
+          mixtral.txt
+          gpt2_unjitted.txt
+          gpt2_jitted.txt
+          gpt2_half.txt
+          gpt2_half_beam.txt
+
+  testmorenvidiabenchmark:
+    name: tinybox green Training Benchmark
+    runs-on: [self-hosted, Linux, tinyboxgreen]
+    defaults:
+      run:
+        shell: bash -o pipefail {0}
+    if: github.repository_owner == 'tinygrad'
+    steps:
+    - name: Checkout Code
+      uses: actions/checkout@v4
+    - name: Symlink models and datasets
+      run: |
+        mkdir -p weights
+        ln -s ~/tinygrad/weights/bpe_simple_vocab_16e6.txt.gz weights/bpe_simple_vocab_16e6.txt.gz
+        ln -s ~/tinygrad/weights/LLaMA weights/LLaMA
+        ln -s ~/tinygrad/extra/datasets/cifar-10-python.tar.gz extra/datasets/cifar-10-python.tar.gz
+        ln -s /raid/weights/mixtral-8x7b-32kseqlen weights/mixtral-8x7b-32kseqlen
+        ln -s /raid/weights/LLaMA-2 weights/LLaMA-2
+        mkdir -p extra/datasets
+        ln -s /raid/datasets/imagenet extra/datasets/imagenet
+    - name: Train MNIST
+      run: time PYTHONPATH=. NV=1 TARGET_EVAL_ACC_PCT=97.3 python3 examples/beautiful_mnist.py | tee beautiful_mnist.txt
+    - name: Run 10 CIFAR training steps
+      run: NV=1 STEPS=10 python3 examples/hlb_cifar10.py | tee train_cifar.txt
+    - name: Run 10 CIFAR training steps w HALF
+      run: NV=1 STEPS=10 DEFAULT_FLOAT=HALF python3 examples/hlb_cifar10.py | tee train_cifar_half.txt
+    - name: Run 10 CIFAR training steps w BF16
+      run: NV=1 STEPS=10 DEFAULT_FLOAT=BFLOAT16 python3 examples/hlb_cifar10.py | tee train_cifar_bf16.txt
+    - name: Run 10 CIFAR training steps w winograd
+      run: NV=1 WINO=1 STEPS=10 python3 examples/hlb_cifar10.py | tee train_cifar_wino.txt
+    - name: Run full CIFAR training w 1 GPU
+      run: time NV=1 DEFAULT_FLOAT=HALF LATEWINO=1 STEPS=1000 TARGET_EVAL_ACC_PCT=93.2 python3 examples/hlb_cifar10.py | tee train_cifar_one_gpu.txt
+    - name: Run full CIFAR training steps w 6 GPUS
+      run: time NV=1 DEFAULT_FLOAT=HALF STEPS=350 BS=1536 GPUS=6 TARGET_EVAL_ACC_PCT=93.2 python3 examples/hlb_cifar10.py | tee train_cifar_six_gpu.txt
+    - name: Run MLPerf resnet eval on training data
+      run: time NV=1 MODEL=resnet python3 examples/mlperf/model_eval.py
+    - name: Run 10 MLPerf ResNet50 training steps (1 gpu)
+      run: NV=1 DEFAULT_FLOAT=HALF BENCHMARK=10 BS=256 GPUS=1 MODEL=resnet python3 examples/mlperf/model_train.py | tee train_resnet_one_gpu.txt
+    - name: Run 10 MLPerf ResNet50 training steps (6 gpu)
+      run: NV=1 DEFAULT_FLOAT=HALF BENCHMARK=10 BS=1536 GPUS=6 MODEL=resnet python3 examples/mlperf/model_train.py | tee train_resnet.txt
+    - name: Run process replay tests
+      if: env.RUN_PROCESS_REPLAY == '1'
+      run: cp test/external/process_replay/process_replay.py ./process_replay.py && git fetch origin master && git -c advice.detachedHead=false checkout origin/master && PYTHONPATH=. python3 process_replay.py
+    - uses: actions/upload-artifact@v4
+      with:
+        name: Speed (NVIDIA Training)
+        path: |
+          beautiful_mnist.txt
+          train_cifar.txt
+          train_cifar_half.txt
+          train_cifar_bf16.txt
+          train_cifar_wino.txt
+          train_cifar_one_gpu.txt
+          train_resnet.txt
+          train_resnet_one_gpu.txt
+          train_cifar_six_gpu.txt
+
+  testamdbenchmark:
+    name: tinybox red Benchmark
+    runs-on: [self-hosted, Linux, tinybox]
+    defaults:
+      run:
+        shell: bash -o pipefail {0}
+    if: github.repository_owner == 'tinygrad'
+    steps:
+    - name: Checkout Code
+      uses: actions/checkout@v4
+    - name: Symlink models and datasets
+      run: |
+        mkdir -p weights
+        ln -s ~/tinygrad/weights/bpe_simple_vocab_16e6.txt.gz weights/bpe_simple_vocab_16e6.txt.gz
+        ln -s ~/tinygrad/weights/LLaMA weights/LLaMA
+        ln -s ~/tinygrad/extra/datasets/cifar-10-python.tar.gz extra/datasets/cifar-10-python.tar.gz
+        ln -s /raid/weights/mixtral-8x7b-32kseqlen weights/mixtral-8x7b-32kseqlen
+        ln -s /raid/weights/LLaMA-2 weights/LLaMA-2
+        ln -s /raid/weights/LLaMA-3 weights/LLaMA-3
+        mkdir -p extra/datasets
+        ln -s /raid/datasets/imagenet extra/datasets/imagenet
+    - name: Show off tinybox
+      run: /opt/rocm/bin/rocm-bandwidth-test
+    # TODO: unstable on AMD
+    #- name: Run model inference benchmark
+    #  run: LD_PRELOAD="/opt/rocm/lib/libhsa-runtime64.so" HSA=1 NOCLANG=1 python3 test/external/external_model_benchmark.py
+    # TODO: unstable on AMD
+    #- name: Test speed vs torch
+    #  run: |
+    #    python3 -c "import torch; print(torch.__version__)"
+    #    LD_PRELOAD="/opt/rocm/lib/libhsa-runtime64.so" HSA=1 BIG=2 TORCHCUDA=1 python3 test/test_speed_v_torch.py | tee torch_speed.txt
+    - name: Test tensor cores
+      run: |
+        AMD=1 python3 test/test_linearizer.py TestLinearizer.test_tensor_cores TestLinearizer.test_tensor_cores_padded
+    - name: Run Tensor Core GEMM (AMD)
+      run: AMD=1 HALF=1 DEBUG=2 python3 extra/gemm/simple_matmul.py | tee matmul_amd.txt
+    # TODO: AMD compiler bug causes this to fail
+    #- name: Fuzz Padded Tensor Core GEMM
+    #  run: HSA=1 M_START=12 M_STOP=20 M_STEP=1 N_START=12 N_STOP=20 N_STEP=1 K_START=28 K_STOP=36 K_STEP=1 HALF=1 TC_OPT=2 DEBUG=2 python3 ./extra/gemm/fuzz_matmul.py
+    - name: Run Stable Diffusion
+      run: AMD=1 python3 examples/stable_diffusion.py --seed 0 --noshow --timing | tee sd.txt
+    - name: Run SDXL
+      run: AMD=1 python3 examples/sdxl.py --seed 0 --noshow --timing | tee sdxl.txt
+    - name: Run LLaMA 7B
+      run: |
+        AMD=1 JIT=0 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_unjitted.txt
+        AMD=1 JIT=1 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_jitted.txt
+    - name: Run LLaMA 7B with BEAM
+      run: AMD=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_beam.txt
+    - name: Run LLaMA 7B on 4 GPUs
+      run: AMD=1 python3 examples/llama.py --gen 1 --size 7B --shard 4 --prompt "Hello." --count 10 --temperature 0  --timing | tee llama_four_gpu.txt
+    - name: Run LLaMA 7B on 6 GPUs
+      run: AMD=1 python3 examples/llama.py --gen 1 --size 7B --shard 6 --prompt "Hello." --count 10 --temperature 0  --timing | tee llama_six_gpu.txt
+    - name: Run LLaMA-3 8B BEAM
+      run: AMD=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/llama3.py --model weights/LLaMA-3/8B-SF-DPO/ --benchmark | tee llama3_beam.txt
+    - name: Run LLaMA-3 8B on 4 GPUs
+      run: AMD=1 python3 examples/llama3.py --shard 4 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark | tee llama3_four_gpu.txt
+    - name: Run LLaMA-3 8B on 6 GPUs
+      run: AMD=1 python3 examples/llama3.py --shard 6 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark | tee llama3_six_gpu.txt
+    - name: Run LLaMA-2 70B
+      run: AMD=1 python3 examples/llama.py --gen 2 --size 70B --shard 6 --prompt "Hello." --count 10 --temperature 0  --timing | tee llama_2_70B.txt
+    - name: Run Mixtral 8x7B
+      run: time AMD=1 python3 examples/mixtral.py --temperature 0 --count 10 --timing | tee mixtral.txt
+    - name: Run GPT2
+      run: |
+        AMD=1 JIT=0 python3 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing | tee gpt2_unjitted.txt
+        AMD=1 JIT=1 python3 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing | tee gpt2_jitted.txt
+    - name: Run GPT2 w HALF
+      run: AMD=1 HALF=1 python3 examples/gpt2.py --count 10 --temperature 0 --timing | tee gpt2_half.txt
+    - name: Run GPT2 w HALF/BEAM
+      run: AMD=1 HALF=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 CAST_BEFORE_VIEW=0 python3 examples/gpt2.py --count 10 --temperature 0 --timing | tee gpt2_half_beam.txt
+    - name: Run process replay tests
+      if: env.RUN_PROCESS_REPLAY == '1'
+      run: cp test/external/process_replay/process_replay.py ./process_replay.py && git fetch origin master && git -c advice.detachedHead=false checkout origin/master && PYTHONPATH=. python3 process_replay.py
+    - uses: actions/upload-artifact@v4
+      with:
+        name: Speed (AMD)
+        path: |
+          onnx_inference_speed.csv
+          torch_speed.txt
+          llama_unjitted.txt
+          llama_jitted.txt
+          llama_beam.txt
+          llama_four_gpu.txt
+          llama_six_gpu.txt
+          llama3_beam.txt
+          llama3_four_gpu.txt
+          llama3_six_gpu.txt
+          llama_2_70B.txt
+          gpt2_unjitted.txt
+          gpt2_jitted.txt
+          gpt2_half.txt
+          gpt2_half_beam.txt
+          matmul.txt
+          matmul_amd.txt
+          sd.txt
+          sdxl.txt
+          mixtral.txt
+
+  testmoreamdbenchmark:
+    name: tinybox red Training Benchmark
+    runs-on: [self-hosted, Linux, tinybox]
+    defaults:
+      run:
+        shell: bash -o pipefail {0}
+    if: github.repository_owner == 'tinygrad'
+    steps:
+    - name: Checkout Code
+      uses: actions/checkout@v4
+    - name: Symlink models and datasets
+      run: |
+        mkdir -p weights
+        ln -s ~/tinygrad/weights/bpe_simple_vocab_16e6.txt.gz weights/bpe_simple_vocab_16e6.txt.gz
+        ln -s ~/tinygrad/weights/LLaMA weights/LLaMA
+        ln -s ~/tinygrad/extra/datasets/cifar-10-python.tar.gz extra/datasets/cifar-10-python.tar.gz
+        ln -s /raid/weights/mixtral-8x7b-32kseqlen weights/mixtral-8x7b-32kseqlen
+        ln -s /raid/weights/LLaMA-2 weights/LLaMA-2
+        mkdir -p extra/datasets
+        ln -s /raid/datasets/imagenet extra/datasets/imagenet
+    - name: Train MNIST
+      run: time PYTHONPATH=. AMD=1 TARGET_EVAL_ACC_PCT=97.3 python3 examples/beautiful_mnist.py | tee beautiful_mnist.txt
+    - name: Run 10 CIFAR training steps
+      run: AMD=1 STEPS=10 python3 examples/hlb_cifar10.py | tee train_cifar.txt
+    - name: Run 10 CIFAR training steps w HALF
+      run: AMD=1 STEPS=10 DEFAULT_FLOAT=HALF python3 examples/hlb_cifar10.py | tee train_cifar_half.txt
+    - name: Run 10 CIFAR training steps w BF16
+      run: AMD=1 STEPS=10 DEFAULT_FLOAT=BFLOAT16 python3 examples/hlb_cifar10.py | tee train_cifar_bf16.txt
+    - name: Run 10 CIFAR training steps w winograd
+      run: AMD=1 WINO=1 STEPS=10 python3 examples/hlb_cifar10.py | tee train_cifar_wino.txt
+    - name: Run full CIFAR training w 1 GPU
+      run: time AMD=1 DEFAULT_FLOAT=HALF LATEWINO=1 STEPS=1000 TARGET_EVAL_ACC_PCT=93.2 python3 examples/hlb_cifar10.py | tee train_cifar_one_gpu.txt
+    - name: Run full CIFAR training steps w 6 GPUS
+      run: time AMD=1 DEFAULT_FLOAT=HALF STEPS=350 BS=1536 GPUS=6 TARGET_EVAL_ACC_PCT=93.2 python3 examples/hlb_cifar10.py | tee train_cifar_six_gpu.txt
+    - name: Run MLPerf resnet eval
+      run: time AMD=1 MODEL=resnet python3 examples/mlperf/model_eval.py
+    - name: Run 10 MLPerf ResNet50 training steps (1 gpu)
+      run: AMD=1 DEFAULT_FLOAT=HALF BENCHMARK=10 BS=256 GPUS=1 MODEL=resnet python3 examples/mlperf/model_train.py | tee train_resnet_one_gpu.txt
+    - name: Run 10 MLPerf ResNet50 training steps (6 gpu)
+      run: AMD=1 DEFAULT_FLOAT=HALF BENCHMARK=10 BS=1536 GPUS=6 MODEL=resnet python3 examples/mlperf/model_train.py | tee train_resnet.txt
+    - name: Run process replay tests
+      if: env.RUN_PROCESS_REPLAY == '1'
+      run: cp test/external/process_replay/process_replay.py ./process_replay.py && git fetch origin master && git -c advice.detachedHead=false checkout origin/master && PYTHONPATH=. python3 process_replay.py
+    - uses: actions/upload-artifact@v4
+      with:
+        name: Speed (AMD Training)
+        path: |
+          beautiful_mnist.txt
+          train_cifar.txt
+          train_cifar_half.txt
+          train_cifar_bf16.txt
+          train_cifar_wino.txt
+          train_cifar_one_gpu.txt
+          train_resnet.txt
+          train_resnet_one_gpu.txt
+          train_cifar_six_gpu.txt
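
The "Run process replay tests" step above is repeated verbatim in every job; unrolled into separate commands, it amounts to the following (a local sketch, assuming a tinygrad checkout in which RUN_PROCESS_REPLAY=1 populated the kernel cache during the benchmark run):

    # local sketch of the shared "Run process replay tests" step
    cp test/external/process_replay/process_replay.py ./process_replay.py  # keep the replay script across the checkout
    git fetch origin master
    git -c advice.detachedHead=false checkout origin/master                # switch to current master for comparison
    PYTHONPATH=. python3 process_replay.py                                 # replay the cached kernels against master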

+ 30 - 0
tinychat/.github/workflows/docs.yml

@@ -0,0 +1,30 @@
+name: Deploy Docs
+on:
+  push:
+    branches:
+      - master
+      - mkdocs
+permissions:
+  contents: write
+jobs:
+  deploy:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+      - name: Configure Git Credentials
+        run: |
+          git config user.name github-actions[bot]
+          git config user.email 41898282+github-actions[bot]@users.noreply.github.com
+      - uses: actions/setup-python@v5
+        with:
+          python-version: 3.x
+      - run: echo "cache_id=$(date --utc '+%V')" >> $GITHUB_ENV
+      - uses: actions/cache@v4
+        with:
+          key: mkdocs-material-${{ env.cache_id }}
+          path: .cache
+          restore-keys: |
+            mkdocs-material-
+      - run: pip install -e .[docs]
+      - run: mkdocs build --strict
+      - run: mkdocs gh-deploy --force
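
Outside of CI, the same docs pipeline can be reproduced locally; a rough sketch of the equivalent commands (the mkdocs-material cache step above only speeds up CI and is omitted here):

    pip install -e '.[docs]'   # tinygrad plus the docs extras, as in the workflow
    mkdocs build --strict      # fail the build on any warning, matching CI
    mkdocs gh-deploy --force   # push the built site to the gh-pages branch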

+ 30 - 0
tinychat/.github/workflows/python-publish.yml

@@ -0,0 +1,30 @@
+# This workflow will upload a Python Package using Twine when a release is created
+# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
+
+name: Upload Python Package
+
+on:
+  release:
+    types: [published]
+  workflow_dispatch:
+
+jobs:
+  deploy:
+    runs-on: ubuntu-latest
+    steps:
+    - uses: actions/checkout@v4
+    - name: Set up Python
+      uses: actions/setup-python@v2
+      with:
+        python-version: '3.x'
+    - name: Install dependencies
+      run: |
+        python -m pip install --upgrade pip
+        pip install setuptools wheel twine
+    - name: Build and publish
+      env:
+        TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
+        TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
+      run: |
+        python setup.py sdist bdist_wheel
+        twine upload dist/*
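
Run by hand, the deploy job is the standard setuptools/Twine release flow; a sketch in which the credentials are placeholders standing in for the PYPI_USERNAME and PYPI_PASSWORD repository secrets:

    python -m pip install --upgrade pip
    pip install setuptools wheel twine
    python setup.py sdist bdist_wheel      # build source and wheel distributions into dist/
    TWINE_USERNAME="<pypi-user>" TWINE_PASSWORD="<pypi-token>" twine upload dist/*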

+ 98 - 0
tinychat/.github/workflows/szdiff.yml

@@ -0,0 +1,98 @@
+name: Check Line Counts
+on:
+  pull_request_target:
+
+# Cancel the workflow in progress if a newer build is about to start.
+concurrency:
+  group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
+  cancel-in-progress: true
+
+jobs:
+  checkbranch:
+    name: Check PR Branch status
+    runs-on: ubuntu-latest
+    outputs:
+      branchstat: ${{ steps.brstat.outputs.stat}}
+    steps:
+      - name: Checkout code from PR branch
+        uses: actions/checkout@v4
+        with:
+          repository: ${{ github.event.pull_request.head.repo.full_name }}
+          ref: ${{ github.event.pull_request.head.sha }}
+          fetch-depth: 0
+      - name: Check whether branch is up-to-date
+        id: brstat
+        run: |
+          git remote add tinygrad https://github.com/tinygrad/tinygrad
+          git fetch tinygrad master
+          echo "${{ github.event.pull_request.head.sha }}"
+          git rev-list --left-right --count  tinygrad/master...${{ github.event.pull_request.head.sha }} | awk '{print "Behind "$1" - Ahead "$2""}'
+          count=$(git rev-list --left-right --count  tinygrad/master...${{ github.event.pull_request.head.sha }} | awk '{print $1}')
+          if [ $count -gt 0 ]
+          then
+            echo "Current branch is behind tinygrad master branch!"
+            echo "stat=true" >> "$GITHUB_OUTPUT"
+          else
+            echo "stat=false" >> "$GITHUB_OUTPUT"
+          fi
+
+  szdiff:
+    name: Core Library Line Difference
+    permissions:
+      contents: read
+      pull-requests: write
+    runs-on: ubuntu-latest
+    needs: checkbranch
+    if: needs.checkbranch.outputs.branchstat == 'false'
+    steps:
+      - name: Checkout code from PR branch
+        uses: actions/checkout@v4
+        with:
+          repository: ${{ github.event.pull_request.head.repo.full_name }}
+          ref: ${{ github.event.pull_request.head.sha }}
+          path: pr
+        # the base defaults to tinygrad master and cannot be another fork's branch, for security purposes
+      - name: Checkout code from tinygrad master
+        uses: actions/checkout@v4
+        with:
+          path: base
+      - name: Set up Python 3.10
+        uses: actions/setup-python@v5
+        with:
+          python-version: '3.10'
+      - name: Count Line Diff
+        run: |
+          pip install tabulate
+          BASE="$GITHUB_WORKSPACE/base"
+          PR="$GITHUB_WORKSPACE/pr"
+          cp "$BASE/sz.py" .
+          echo "loc_content<<EOF" >> "$GITHUB_ENV"
+          python sz.py "$BASE" "$PR" >> "$GITHUB_ENV"
+          echo "EOF" >> "$GITHUB_ENV"
+      - name: Comment Code Line Diff
+        continue-on-error: false
+        uses: marocchino/sticky-pull-request-comment@v2
+        with:
+          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+          ignore_empty: true
+          skip_unchanged: true
+          recreate: true
+          message: ${{ env.loc_content }}
+
+  rebase:
+    name: Core Library Line Difference
+    permissions:
+      pull-requests: write
+    runs-on: ubuntu-latest
+    needs: checkbranch
+    if: needs.checkbranch.outputs.branchstat == 'true'
+    steps:
+      - name: Comment Rebase
+        continue-on-error: false
+        uses: marocchino/sticky-pull-request-comment@v2
+        with:
+          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+          skip_unchanged: true
+          recreate: true
+          message: |
+            This branch currently is behind tinygrad/master. The line count difference bot is disabled.

+ 533 - 0
tinychat/.github/workflows/test.yml

@@ -0,0 +1,533 @@
+name: Unit Tests
+env:
+  # increment this when downloads substantially change to avoid the internet
+  DOWNLOAD_CACHE_VERSION: '5'
+  RUN_PROCESS_REPLAY: 1
+
+on:
+  push:
+    branches:
+      - master
+  pull_request:
+  workflow_dispatch:
+
+jobs:
+  uops:
+    name: uops tests
+    runs-on: ubuntu-latest
+    timeout-minutes: 10
+    steps:
+    - name: Checkout Code
+      uses: actions/checkout@v4
+    - name: Set up Python 3.12
+      uses: actions/setup-python@v5
+      with:
+        python-version: 3.12
+    - name: Cache python packages
+      uses: actions/cache@v4
+      with:
+        path: ${{ env.Python3_ROOT_DIR }}/lib/python3.12/site-packages
+        key: uops-packages-${{ hashFiles('**/setup.py') }}-3.12
+    - name: Install dependencies
+      run: pip install -e '.[testing]' --extra-index-url https://download.pytorch.org/whl/cpu
+    - name: Test IMAGE=2 support
+      run: |
+        IMAGE=2 PYTHON=1 python3 test/test_ops.py TestOps.test_gemm
+        IMAGE=2 PYTHON=1 python3 test/test_ops.py TestOps.test_simple_conv2d
+    - name: Test emulated METAL tensor cores
+      run: DEBUG=2 EMULATE_METAL=1 FORWARD_ONLY=1 PYTHON=1 python3 test/test_ops.py TestOps.test_big_gemm
+    - name: Test emulated AMD tensor cores
+      run: |
+        PYTHONPATH=. DEBUG=2 EMULATE_AMD=1 FORWARD_ONLY=1 PYTHON=1 N=16 HALF=1 ACC_HALF=0 python3 ./extra/gemm/simple_matmul.py
+        PYTHONPATH=. DEBUG=2 EMULATE_AMD=1 FORWARD_ONLY=1 PYTHON=1 N=64 HALF=1 ACC_HALF=0 python3 ./extra/gemm/simple_matmul.py
+        PYTHONPATH=. DEBUG=2 EMULATE_AMD=1 FORWARD_ONLY=1 PYTHON=1 N=16 HALF=1 ACC_HALF=1 python3 ./extra/gemm/simple_matmul.py
+        PYTHONPATH=. DEBUG=2 EMULATE_AMD=1 FORWARD_ONLY=1 PYTHON=1 N=64 HALF=1 ACC_HALF=1 python3 ./extra/gemm/simple_matmul.py
+    - name: Test emulated CUDA tensor cores
+      run: DEBUG=2 EMULATE_CUDA=1 FORWARD_ONLY=1 PYTHON=1 python3 test/test_ops.py TestOps.test_gemm
+    - name: Full test tensor cores
+      run: |
+        PYTHONPATH=. DEBUG=2 EMULATE_METAL=1 FORWARD_ONLY=1 PYTHON=1 python3 ./test/test_linearizer.py TestLinearizer.test_tensor_cores
+        PYTHONPATH=. DEBUG=2 EMULATE_AMD=1 FORWARD_ONLY=1 PYTHON=1 python3 ./test/test_linearizer.py TestLinearizer.test_tensor_cores
+        PYTHONPATH=. DEBUG=2 EMULATE_CUDA=1 FORWARD_ONLY=1 PYTHON=1 python3 ./test/test_linearizer.py TestLinearizer.test_tensor_cores
+    - name: Test dtype with Python emulator
+      run: DEBUG=1 PYTHONPATH=. PYTHON=1 python3 -m pytest -n=auto test/test_dtype.py test/test_dtype_alu.py
+    - name: Test ops with Python emulator
+      run: DEBUG=2 PYTHON=1 python3 -m pytest -n=auto test/test_ops.py -k "not (test_split or test_simple_cumsum or test_cumsum or test_einsum or test_dot or test_dot_1d or test_big_gemm or test_broadcastdot or test_multidot or test_var_axis or test_std_axis or test_broadcast_full or test_broadcast_partial or test_simple_conv3d or test_dilated_conv_transpose2d or test_simple_conv_transpose3d or test_large_input_conv2d or test_maxpool2d or test_maxpool2d_simple or test_maxpool2d_bigger_stride or test_avgpool2d or test_cat or test_scaled_product_attention or test_scaled_product_attention_causal or test_slice_fancy_indexing_dim_inject_none or test_slice_fancy_indexing_list_indices or test_slice_fancy_indexing_no_dim_collapse or test_slice_fancy_indexing_tuple_indices or test_slice_fancy_indexing_list_with_tensors or test_slice_fancy_indexing_dim_collapse_int)" --durations=20
+    - name: Test uops with Python emulator
+      run: PYTHON=1 python3 -m pytest test/test_uops.py --durations=20
+    - name: Test symbolic with Python emulator
+      run: PYTHONPATH=. PYTHON=1 python3 test/test_symbolic_ops.py
+    - name: test_linearizer_failures with Python emulator
+      run: PYTHONPATH=. PYTHON=1 python3 -m pytest -rA test/test_linearizer_failures.py::TestLinearizerFailures::test_failure_1
+
+  linter:
+    name: Linters
+    runs-on: ubuntu-latest
+    timeout-minutes: 10
+
+    # TODO: run the pre-commit hook to replace a lot of this
+    steps:
+    - name: Checkout Code
+      uses: actions/checkout@v4
+    - name: Set up Python 3.8
+      uses: actions/setup-python@v5
+      with:
+        python-version: 3.8
+    - name: Cache python packages
+      uses: actions/cache@v4
+      with:
+        path: ${{ env.Python3_ROOT_DIR }}/lib/python3.8/site-packages
+        key: linting-packages-${{ hashFiles('**/setup.py') }}-3.8
+    - name: Install dependencies
+      run: pip install -e '.[linting,testing,docs]' --extra-index-url https://download.pytorch.org/whl/cpu
+    - name: Lint with pylint
+      run: python -m pylint --disable=all -e W0311 -e C0303 --jobs=0 --indent-string='  ' **/*.py
+    - name: Lint with ruff
+      run: |
+        pip3 install --upgrade --force-reinstall ruff
+        python3 -m ruff check .
+    - name: Lint tinygrad with pylint
+      run: python -m pylint tinygrad/
+    - name: Run mypy
+      run: python -m mypy --strict-equality
+    - name: Test Docs
+      run: |
+        python docs/abstractions2.py
+        python docs/abstractions3.py
+    - name: Test Docs Build
+      run: mkdocs build --strict
+    - name: Test Quickstart
+      run: awk '/```python/{flag=1;next}/```/{flag=0}flag' docs/quickstart.md > quickstart.py &&  PYTHONPATH=. python quickstart.py
+    - name: Test README
+      run: awk '/```python/{flag=1;next}/```/{flag=0}flag' README.md > README.py &&  PYTHONPATH=. python README.py
+    - name: Fuzz Test symbolic
+      run: python test/external/fuzz_symbolic.py
+    - name: Fuzz Test shapetracker
+      run: |
+        PYTHONPATH="." python test/external/fuzz_shapetracker.py
+        PYTHONPATH="." python test/external/fuzz_shapetracker_math.py
+    - name: Test to_movement_ops
+      run: PYTHONPATH="." python extra/to_movement_ops.py
+    - name: Use as an external package
+      run: |
+        mkdir $HOME/test_external_dir
+        cd $HOME/test_external_dir
+        python -m venv venv
+        source venv/bin/activate
+        pip install $GITHUB_WORKSPACE
+        python -c "from tinygrad.tensor import Tensor; print(Tensor([1,2,3,4,5]))"
+    - name: Test DEBUG
+      run: DEBUG=100 python3 -c "from tinygrad import Tensor; N = 1024; a, b = Tensor.rand(N, N), Tensor.rand(N, N); c = (a.reshape(N, 1, N) * b.T.reshape(1, N, N)).sum(axis=2); print((c.numpy() - (a.numpy() @ b.numpy())).mean())"
+    - name: Repo line count <8500 lines
+      run: MAX_LINE_COUNT=8500 python sz.py
+
+  testopencl:
+    strategy:
+      fail-fast: false
+      matrix:
+        task: [optimage, onnx]
+
+    name: ${{ matrix.task=='optimage'&&'GPU IMAGE+compile Tests' || matrix.task=='onnx'&&'ONNX+Optimization Tests' }}
+    runs-on: ubuntu-20.04
+    timeout-minutes: 10
+
+    steps:
+      - name: Checkout Code
+        uses: actions/checkout@v4
+      - name: Install OpenCL
+        run: |
+          echo 'Acquire::http::Pipeline-Depth "5";' | sudo tee -a /etc/apt/apt.conf.d/99parallel
+          echo "deb [ allow-insecure=yes ] https://apt.repos.intel.com/oneapi all main" | sudo tee /etc/apt/sources.list.d/oneAPI.list
+          sudo apt update || true
+          sudo apt install --allow-unauthenticated -y --no-install-recommends \
+            intel-oneapi-runtime-openmp=2023.2.1-16 intel-oneapi-runtime-compilers-common=2023.2.1-16 intel-oneapi-runtime-compilers=2023.2.1-16 \
+            intel-oneapi-runtime-dpcpp-sycl-opencl-cpu=2023.2.1-16 intel-oneapi-runtime-tbb-common=2021.10.0-49541 \
+            intel-oneapi-runtime-tbb=2021.10.0-49541 intel-oneapi-runtime-opencl=2023.2.1-16
+      - name: Set up Python 3.11
+        uses: actions/setup-python@v5
+        with:
+          python-version: 3.11
+      - name: Cache python packages
+        uses: actions/cache@v4
+        with:
+          path: ${{ env.Python3_ROOT_DIR }}/lib/python3.11/site-packages
+          key: testing-packages-${{ hashFiles('**/setup.py') }}
+      - name: Cache downloads
+        uses: actions/cache@v4
+        with:
+          path: ~/.cache/tinygrad/downloads/
+          key: downloads-cache-${{ matrix.task }}-${{ env.DOWNLOAD_CACHE_VERSION }}
+      - name: Install Dependencies
+        run: pip install -e '.[testing,testing_tf]' --extra-index-url https://download.pytorch.org/whl/cpu
+      - if: ${{ matrix.task == 'optimage' }}
+        name: Run Kernel Count Test
+        run: PYTHONPATH="." GPU=1 python -m pytest -n=auto test/external/external_test_opt.py
+      - if: ${{ matrix.task == 'optimage'}}
+        name: Test WINO=1
+        run: GPU=1 DEBUG=2 WINO=1 python3 test/test_ops.py TestOps.test_simple_conv2d
+      - if: ${{ matrix.task == 'optimage'}}
+        name: Test GPU IMAGE=2 ops + training
+        run: |
+          GPU=1 IMAGE=2 python -m pytest -n=auto test/test_ops.py
+          GPU=1 IMAGE=2 python3 test/models/test_end2end.py TestEnd2End.test_linear_mnist
+      - if: ${{ matrix.task == 'optimage' }}
+        name: Test openpilot model compile and size
+        run: |
+          PYTHONPATH="." DEBUG=2 ALLOWED_KERNEL_COUNT=208 FLOAT16=1 DEBUGCL=1 GPU=1 IMAGE=2 python examples/openpilot/compile2.py
+          python -c 'import os; assert os.path.getsize("/tmp/output.thneed") < 100_000_000'
+      - if: ${{ matrix.task == 'optimage' }}
+        name: Test openpilot model correctness (float32)
+        run: PYTHONPATH="." FLOAT16=0 DEBUGCL=1 GPU=1 IMAGE=2 python examples/openpilot/compile2.py
+      - if: ${{ matrix.task == 'optimage' }}
+        name: Test openpilot alt model correctness (float32)
+        run: PYTHONPATH="." FLOAT16=0 DEBUGCL=1 GPU=1 IMAGE=2 python examples/openpilot/compile2.py https://github.com/commaai/openpilot/raw/3799fe46b3a629e491d4b8498b8ae83e4c88c304/selfdrive/modeld/models/supercombo.onnx
+      - if: ${{ matrix.task == 'optimage' }}
+        name: Test openpilot fastvits model correctness (float32)
+        run: PYTHONPATH="." FLOAT16=0 DEBUGCL=1 GPU=1 IMAGE=2 python examples/openpilot/compile2.py https://github.com/commaai/openpilot/raw/9118973ed03c1ae1d40cf69a29507ec2cc78efd7/selfdrive/modeld/models/supercombo.onnx
+      - if: ${{ matrix.task == 'optimage' }}
+        name: Compile EfficientNet to C and test it
+        run: |
+          CLANG=1 PYTHONPATH="." python examples/compile_efficientnet.py > recognize.c
+          clang -O2 recognize.c -lm -o recognize
+          cat test/models/efficientnet/Chicken.jpg | ./recognize | grep cock
+      - if: ${{ matrix.task == 'onnx' }}
+        name: Test ONNX (GPU)
+        run: GPU=1 python -m pytest -n=auto test/external/external_test_onnx_backend.py --durations=20
+      - if: ${{ matrix.task == 'onnx' }}
+        name: Test ONNX (CLANG)
+        run: CLANG=1 python -m pytest -n=auto test/external/external_test_onnx_backend.py --durations=20
+      - if: ${{ matrix.task == 'onnx' }}
+        name: Test Action Space
+        run: PYTHONPATH="." GPU=1 python3 extra/optimization/get_action_space.py
+      - if: ${{ matrix.task == 'onnx' }}
+        name: Test Beam Search
+        run: PYTHONPATH="." GPU=1 IGNORE_BEAM_CACHE=1 python3 -m pytest extra/optimization/test_beam_search.py
+      - if: ${{ matrix.task == 'onnx' }}
+        name: Test MLPerf optimizers
+        run: GPU=1 python -m pytest -n=auto test/external/external_test_optim.py --durations=20
+      - if: ${{ matrix.task == 'onnx' }}
+        name: Test MLPerf losses
+        run: GPU=1 python -m pytest -n=auto test/external/external_test_losses.py --durations=20
+      - if: ${{ matrix.task == 'onnx' }}
+        name: Test MLPerf metrics
+        run: GPU=1 python -m pytest -n=auto test/external/external_test_metrics.py --durations=20
+      - if: ${{ matrix.task == 'onnx' }}
+        name: Test MLPerf datasets
+        run: GPU=1 python -m pytest -n=auto test/external/external_test_datasets.py --durations=20
+      - if: ${{ matrix.task == 'onnx' }}
+        name: Test THREEFRY
+        run: PYTHONPATH=. THREEFRY=1 GPU=1 python3 -m pytest test/test_randomness.py test/test_jit.py --durations=20
+      - if: ${{ matrix.task == 'onnx' }}
+        name: Run handcode_opt
+        run: PYTHONPATH=. MODEL=resnet GPU=1 DEBUG=1 BS=4 HALF=0 python3 examples/handcode_opt.py
+
+  #testwebgpu:
+  #  name: WebGPU Tests
+  #  runs-on: macos-13
+  #  timeout-minutes: 10
+  #  steps:
+  #  - name: Checkout Code
+  #    uses: actions/checkout@v4
+  #  - name: Set up Python 3.11
+  #    uses: actions/setup-python@v5
+  #    with:
+  #      python-version: 3.11
+  #  - name: Cache python packages
+  #    uses: actions/cache@v4
+  #    with:
+  #      path: /Users/runner/Library/Python/3.11/lib/python/site-packages
+  #      key: webgpu-testing-user3-packages-${{ hashFiles('**/setup.py') }}
+  #  - name: Install Dependencies
+  #    run: pip install --user -e '.[webgpu,testing]' --extra-index-url https://download.pytorch.org/whl/cpu
+  #  - name: Cache downloads
+  #    uses: actions/cache@v4
+  #    with:
+  #      path: ~/Library/Caches/tinygrad/downloads/
+  #      key: downloads-cache-webgpu-${{ env.DOWNLOAD_CACHE_VERSION }}
+  #  - name: Check Device.DEFAULT (WEBGPU) and print some source
+  #    run: |
+  #      WEBGPU=1 python -c "from tinygrad import Device; assert Device.DEFAULT == 'WEBGPU', Device.DEFAULT"
+  #      WEBGPU=1 DEBUG=4 FORWARD_ONLY=1 python3 test/test_ops.py TestOps.test_add
+    #- name: Run webgpu pytest
+    #  run: WEBGPU=1 WGPU_BACKEND_TYPE=Metal python -m pytest -n=auto
+  #  - name: Run selected webgpu tests
+  #    run: |
+  #      WEBGPU=1 WGPU_BACKEND_TYPE=Metal python -m pytest -n=auto test/test_ops.py test/test_dtype.py \
+  #      test/test_jit.py test/test_symbolic_ops.py test/test_symbolic_jit.py test/test_linearizer.py \
+  #      test/test_linearizer_failures.py test/test_nn.py
+  #  - name: Build WEBGPU Efficientnet
+  #    run: WEBGPU=1 WGPU_BACKEND_TYPE=Metal python -m examples.compile_efficientnet
+  #  - name: Install Puppeteer
+  #    run: npm install puppeteer
+  #  - name: Run WEBGPU Efficientnet
+  #    run: node test/web/test_webgpu.js
+
+  testmetal:
+    name: Metal Tests
+    runs-on: macos-14
+    timeout-minutes: 20
+
+    steps:
+    - name: Checkout Code
+      uses: actions/checkout@v4
+      with:
+        fetch-depth: 2 # NOTE: this fetches the HEAD commit of the PR
+    - name: Set up Python 3.11
+      uses: actions/setup-python@v5
+      with:
+        python-version: 3.11
+    - name: Cache python packages
+      uses: actions/cache@v4
+      with:
+        path: /Users/runner/Library/Python/3.11/lib/python/site-packages
+        key: metal-m1-testing-user3-packages-${{ hashFiles('**/setup.py') }}
+    - name: Install Dependencies
+      run: pip install --user -e '.[testing]' --extra-index-url https://download.pytorch.org/whl/cpu
+    - name: Cache downloads
+      uses: actions/cache@v4
+      with:
+        path: ~/Library/Caches/tinygrad/downloads/
+        key: downloads-cache-metal-only-${{ env.DOWNLOAD_CACHE_VERSION }}
+    - name: Check Device.DEFAULT (METAL) and print some source
+      run: |
+        METAL=1 python -c "from tinygrad import Device; assert Device.DEFAULT == 'METAL', Device.DEFAULT"
+        METAL=1 DEBUG=4 FORWARD_ONLY=1 python3 test/test_ops.py TestOps.test_add
+    - name: Run metal test
+      run: JIT=2 METAL=1 python -m pytest -n=auto test/ --ignore=test/external --ignore=test/models --durations=20
+    - name: Run real world test
+      run: JIT=2 METAL=1 python -m pytest -n=auto test/models/test_real_world.py --durations=20
+    - name: Run ONNX
+      run: JIT=2 METAL=1 python -m pytest -n=auto test/external/external_test_onnx_backend.py --durations=20
+    - name: Test tensor core ops (fake)
+      run: TC=2 METAL=1 DEBUG=3 python test/test_ops.py TestOps.test_gemm
+    - name: Test tensor core ops (real)
+      run: METAL=1 DEBUG=3 python test/test_ops.py TestOps.test_big_gemm
+    - name: Test LLaMA compile speed
+      run: PYTHONPATH="." METAL=1 python test/external/external_test_speed_llama.py
+    - name: Test Beam Search
+      run: PYTHONPATH="." METAL=1 IGNORE_BEAM_CACHE=1 python3 -m pytest extra/optimization/test_beam_search.py
+    - name: Fuzz Test linearizer
+      run: PYTHONPATH="." METAL=1 CACHELEVEL=0 FUZZ_ALL_ACTIONS=1 DEPTH=2 FUZZ_N=48 FUZZ_MAX_SIZE=10000000 python test/external/fuzz_linearizer.py
+    - name: Fuzz Test models schedule
+      run: FUZZ_SCHEDULE=1 FUZZ_SCHEDULE_MAX_PATHS=5 python -m pytest test/models/test_train.py test/models/test_end2end.py
+    - name: Run TRANSCENDENTAL math
+      run: TRANSCENDENTAL=2 python -m pytest -n=auto test/test_ops.py::TestOps::test_sin test/test_ops.py::TestOps::test_cos test/test_ops.py::TestOps::test_tan test/test_ops.py::TestOps::test_exp test/test_ops.py::TestOps::test_log --durations=20
+    - name: Run process replay tests
+      run: |
+        export PR_TITLE=$(jq -r .pull_request.title "$GITHUB_EVENT_PATH")
+        export COMMIT_MESSAGE=$(git show -s --format=%B ${{ github.event.pull_request.head.sha }})
+        cp test/external/process_replay/process_replay.py ./process_replay.py && git fetch origin master && git -c advice.detachedHead=false checkout origin/master && PYTHONPATH=. python3 process_replay.py
+
+#  testwebgl:
+#    name: WebGL Tests
+#    runs-on: ubuntu-latest
+#    timeout-minutes: 10
+#
+#    steps:
+#    - name: Checkout Code
+#      uses: actions/checkout@v3
+#    - name: Set up Python 3.11
+#      uses: actions/setup-python@v4
+#      with:
+#        python-version: 3.11
+#    - name: Cache python packages
+#      uses: actions/cache@v4
+#      with:
+#        path: ${{ env.Python3_ROOT_DIR }}/lib/python3.11/site-packages
+#        key: webgl-testing-packages-${{ hashFiles('**/setup.py') }}
+#    - name: Install Dependencies
+#      run: pip install -e '.[webgl,testing]' --extra-index-url https://download.pytorch.org/whl/cpu
+#    - name: Cache downloads
+#      uses: actions/cache@v4
+#      with:
+#        path: ~/Library/Caches/tinygrad/downloads/
+#        key: downloads-cache-webgl-${{ env.DOWNLOAD_CACHE_VERSION }}
+#    - name: Prepare
+#      run: |
+#        sudo apt-get -y install xvfb
+#        sudo /usr/bin/Xvfb :0 -screen 0 4096x4096x24+32 &
+#    - name: Run selected webgl tests
+#      run: WEBGL=1 python -m pytest -n=auto test/test_ops.py test/test_dtype.py test/test_jit.py
+#    - name: Build WebGL Efficientnet
+#      run: WEBGL=1 python -m examples.compile_efficientnet
+
+  tests:
+    strategy:
+      fail-fast: false
+      matrix:
+        backend: [llvm, clang, gpu, ptx, amd, nv] #, triton]
+
+    name: Tests on (${{ matrix.backend }})
+    runs-on: ubuntu-latest
+    timeout-minutes: 20
+
+    steps:
+      - name: Checkout Code
+        uses: actions/checkout@v4
+        with:
+          fetch-depth: 2 # NOTE: this fetches the HEAD commit of the PR
+      - name: Set up Python 3.11
+        uses: actions/setup-python@v5
+        with:
+          python-version: 3.11
+      - name: Cache python packages
+        uses: actions/cache@v4
+        with:
+          path: ${{ env.Python3_ROOT_DIR }}/lib/python3.11/site-packages
+          key: ${{ matrix.backend }}-packages-${{ hashFiles('**/setup.py') }}
+      - name: Cache downloads
+        uses: actions/cache@v4
+        with:
+          path: ~/.cache/tinygrad/downloads/
+          key: downloads-cache-${{ matrix.backend }}-${{ env.DOWNLOAD_CACHE_VERSION }}
+      - name: Set env
+        run: printf "${{ matrix.backend == 'llvm' && 'LLVM=1' || matrix.backend == 'clang' && 'CLANG=1' || matrix.backend == 'gpu' && 'GPU=1' || matrix.backend == 'PTX' && 'FORWARD_ONLY=1\nJIT=1\nOPT=2\nNV=1\nPTX=1\nMOCKGPU=1' || matrix.backend == 'triton' && 'FORWARD_ONLY=1\nJIT=1\nOPT=2\nNV=1\nMOCKGPU=1\nTRITON=1\nTRITON_PTXAS_PATH=/usr/bin/ptxas' || matrix.backend == 'amd' && 'AMD=1\nMOCKGPU=1\nFORWARD_ONLY=1' || matrix.backend == 'nv' && 'NV=1\nMOCKGPU=1\nFORWARD_ONLY=1' }}" >> $GITHUB_ENV
+      - name: Install OpenCL
+        if: matrix.backend == 'gpu'
+        run: |
+          echo 'Acquire::http::Pipeline-Depth "5";' | sudo tee -a /etc/apt/apt.conf.d/99parallel
+          echo "deb [ allow-insecure=yes ] https://apt.repos.intel.com/oneapi all main" | sudo tee /etc/apt/sources.list.d/oneAPI.list
+          sudo apt update || true
+          sudo apt install --allow-unauthenticated -y --no-install-recommends opencl-headers \
+            intel-oneapi-runtime-openmp=2023.2.1-16 intel-oneapi-runtime-compilers-common=2023.2.1-16 intel-oneapi-runtime-compilers=2023.2.1-16 \
+            intel-oneapi-runtime-dpcpp-sycl-opencl-cpu=2023.2.1-16 intel-oneapi-runtime-tbb-common=2021.10.0-49541 \
+            intel-oneapi-runtime-tbb=2021.10.0-49541 intel-oneapi-runtime-opencl=2023.2.1-16
+      - name: Install packages (cuda)
+        if: matrix.backend == 'ptx' || matrix.backend == 'triton' || matrix.backend == 'nv'
+        run: |
+          echo 'Acquire::http::Pipeline-Depth "5";' | sudo tee -a /etc/apt/apt.conf.d/99parallel
+          sudo apt update -y || true
+          sudo apt install -y --no-install-recommends git g++ cmake ninja-build llvm-15-dev zlib1g-dev libglew-dev \
+            flex bison libfl-dev libboost-thread-dev libboost-filesystem-dev nvidia-cuda-toolkit-gcc libzstd-dev
+      - name: Cache gpuocelot
+        if: matrix.backend == 'ptx' || matrix.backend == 'triton' || matrix.backend == 'nv'
+        id: cache-build
+        uses: actions/cache@v4
+        env:
+          cache-name: cache-gpuocelot-build
+        with:
+          path: ${{ github.workspace }}/gpuocelot/ocelot
+          key: ubuntu22.04-gpuocelot-4524e34adb7eaccc6f71262f2e21d7052bb17c2f-rebuild-7
+      - name: Clone/compile gpuocelot
+        if: (matrix.backend == 'ptx' || matrix.backend == 'triton' || matrix.backend == 'nv') && steps.cache-build.outputs.cache-hit != 'true'
+        run: |
+          git clone --recurse-submodules https://github.com/gpuocelot/gpuocelot.git ${{ github.workspace }}/gpuocelot
+          cd ${{ github.workspace }}/gpuocelot/ocelot
+          git checkout 4524e34adb7eaccc6f71262f2e21d7052bb17c2f
+          mkdir build
+          cd build
+          cmake .. -Wno-dev -G Ninja -DOCELOT_BUILD_TOOLS=OFF -DCMAKE_BUILD_ALWAYS=0 -DBUILD_TESTS_CUDA=OFF
+          ninja
+      - name: Install gpuocelot
+        if: matrix.backend == 'ptx' || matrix.backend == 'triton' || matrix.backend == 'nv'
+        run: |
+          cd ${{ github.workspace }}/gpuocelot/ocelot/build
+          sudo ninja install -d explain
+      - name: Install packages (amd)
+        if: matrix.backend == 'amd'
+        run: |
+          echo 'Acquire::http::Pipeline-Depth "5";' | sudo tee -a /etc/apt/apt.conf.d/99parallel
+          wget https://repo.radeon.com/rocm/rocm.gpg.key -O - | gpg --dearmor | sudo tee /etc/apt/keyrings/rocm.gpg > /dev/null
+          sudo tee /etc/apt/sources.list.d/rocm.list <<'EOF'
+          deb [arch=amd64 signed-by=/etc/apt/keyrings/rocm.gpg] https://repo.radeon.com/rocm/apt/6.1.2 jammy main
+          EOF
+          echo -e 'Package: *\nPin: release o=repo.radeon.com\nPin-Priority: 600' | sudo tee /etc/apt/preferences.d/rocm-pin-600
+          sudo apt update || true
+          sudo apt install --no-install-recommends --allow-unauthenticated -y hsa-rocr comgr hsa-rocr-dev liburing-dev libc6-dev
+          curl -s https://api.github.com/repos/Qazalin/remu/releases/latest | \
+          jq -r '.assets[] | select(.name == "libremu.so").browser_download_url' | \
+          sudo xargs curl -L -o /usr/local/lib/libremu.so
+          sudo tee --append /etc/ld.so.conf.d/rocm.conf <<'EOF'
+            /opt/rocm/lib
+            /opt/rocm/lib64
+          EOF
+          sudo ldconfig
+      - name: Install dependencies
+        run: pip install -e '.[testing${{matrix.backend=='llvm'&&',llvm'||matrix.backend=='ptx'&&',cuda'||matrix.backend=='triton'&&',triton'||''}}]' --extra-index-url https://download.pytorch.org/whl/cpu --extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/
+      - name: Check Device.DEFAULT and print some source
+        run: |
+          PYTHONPATH=${{ github.workspace }} python3 -c "from tinygrad import Device; assert Device.DEFAULT in ['LLVM','CLANG','CUDA','GPU','AMD','NV'], Device.DEFAULT"
+          DEBUG=5 PYTHONPATH=${{ github.workspace }} FORWARD_ONLY=1 python3 test/test_ops.py TestOps.test_add
+      - name: Verify OpenCL autogen
+        if: matrix.backend == 'gpu'
+        run: |
+          cp tinygrad/runtime/autogen/opencl.py /tmp/opencl.py.bak
+          ./autogen_stubs.sh opencl
+          diff /tmp/opencl.py.bak tinygrad/runtime/autogen/opencl.py
+      - name: Verify CUDA autogen
+        if: matrix.backend == 'nv'
+        run: |
+          cp tinygrad/runtime/autogen/cuda.py /tmp/cuda.py.bak
+          cp tinygrad/runtime/autogen/nv_gpu.py /tmp/nv_gpu.py.bak
+          ./autogen_stubs.sh cuda
+          ./autogen_stubs.sh nv
+          diff /tmp/cuda.py.bak tinygrad/runtime/autogen/cuda.py
+          diff /tmp/nv_gpu.py.bak tinygrad/runtime/autogen/nv_gpu.py
+      - name: Verify AMD autogen
+        if: matrix.backend == 'amd'
+        run: |
+          cp tinygrad/runtime/autogen/hsa.py /tmp/hsa.py.bak
+          cp tinygrad/runtime/autogen/comgr.py /tmp/comgr.py.bak
+          cp tinygrad/runtime/autogen/amd_gpu.py /tmp/amd_gpu.py.bak
+          ./autogen_stubs.sh hsa
+          ./autogen_stubs.sh comgr
+          ./autogen_stubs.sh amd
+          diff /tmp/hsa.py.bak tinygrad/runtime/autogen/hsa.py
+          diff /tmp/comgr.py.bak tinygrad/runtime/autogen/comgr.py
+          diff /tmp/amd_gpu.py.bak tinygrad/runtime/autogen/amd_gpu.py
+      - name: Verify Linux autogen
+        if: matrix.backend == 'amd'
+        run: |
+          cp tinygrad/runtime/autogen/libc.py /tmp/libc.py.bak
+          cp tinygrad/runtime/autogen/io_uring.py /tmp/io_uring.py.bak
+          ./autogen_stubs.sh libc
+          ./autogen_stubs.sh io_uring
+          diff /tmp/libc.py.bak tinygrad/runtime/autogen/libc.py
+          diff /tmp/io_uring.py.bak tinygrad/runtime/autogen/io_uring.py
+      - name: Run pytest (not cuda or amd)
+        if: matrix.backend!='ptx' && matrix.backend!='triton' && matrix.backend != 'amd' && matrix.backend != 'nv'
+        run: python -m pytest -n=auto test/ --durations=20
+      # - name: Run test_ops with FUZZ_UOPS=1
+      #   if: matrix.backend!='cuda' && matrix.backend!='ptx' && matrix.backend!='triton' && matrix.backend != 'amd' && matrix.backend != 'nv'
+      #   run: FUZZ_UOPS=1 python -m pytest -n=auto test/test_ops.py --durations=20
+      - name: Run ONNX (only LLVM)
+        if: matrix.backend == 'llvm'
+        run: python -m pytest -n=auto test/external/external_test_onnx_backend.py --durations=20
+      - name: Run pytest (cuda)
+        if: matrix.backend=='ptx'||matrix.backend=='triton'||matrix.backend=='nv'
+        run: python -m pytest -n=auto test/ -k 'not (half or test_efficientnet_safetensors)' --ignore=test/external --ignore=test/models --ignore test/test_gc.py --durations=20
+      - name: Run pytest (amd)
+        if: matrix.backend=='amd'
+        run: python -m pytest -n=auto test/test_ops.py test/test_dtype.py test/test_dtype_alu.py test/test_linearizer.py test/test_randomness.py test/imported/test_indexing.py test/test_hcq.py --durations=20
+      - name: Run TRANSCENDENTAL math
+        run: TRANSCENDENTAL=2 python -m pytest -n=auto test/test_ops.py::TestOps::test_sin test/test_ops.py::TestOps::test_cos test/test_ops.py::TestOps::test_tan test/test_ops.py::TestOps::test_exp test/test_ops.py::TestOps::test_log --durations=20
+      - name: Run process replay tests
+        run: |
+          export PR_TITLE=$(jq -r .pull_request.title "$GITHUB_EVENT_PATH")
+          export COMMIT_MESSAGE=$(git show -s --format=%B ${{ github.event.pull_request.head.sha }})
+          cp test/external/process_replay/process_replay.py ./process_replay.py && git fetch origin master && git -c advice.detachedHead=false checkout origin/master && PYTHONPATH=. python3 process_replay.py
+
+  #testunicorn:
+  #  name: ARM64 unicorn Test
+  #  runs-on: ubuntu-latest
+  #  timeout-minutes: 10
+  #  steps:
+  #    - name: Checkout Code
+  #      uses: actions/checkout@v4
+  #    - name: Set up Python 3.11
+  #      uses: actions/setup-python@v5
+  #      with:
+  #        python-version: 3.11
+  #    - name: Cache python packages
+  #      uses: actions/cache@v4
+  #      with:
+  #        path: ${{ env.Python3_ROOT_DIR }}/lib/python3.11/site-packages
+  #        key: testing-arm-packages-${{ hashFiles('**/setup.py') }}
+  #    - name: Install cross-assembler
+  #      run: |
+  #        sudo apt update -y
+  #        sudo apt install -y --no-install-recommends gcc-aarch64-linux-gnu
+  #    - name: Install dependencies
+  #      run: pip install -e '.[testing,arm]' --extra-index-url https://download.pytorch.org/whl/cpu
+  #    - name: Test arm
+  #      run: CI=1 ARM64=1 CLANG=1 python -m pytest -n=auto test/ -k 'not (test_nn.py and (test_conv_transpose2d or test_conv2d))' --ignore=test/models --ignore=test/test_speed_v_torch.py --ignore=test/test_net_speed.py --ignore=test/test_specific_conv.py  --ignore=test/unit/test_disk_tensor.py

+ 56 - 0
tinychat/.gitignore

@@ -0,0 +1,56 @@
+__pycache__
+.venv/
+.vscode
+.DS_Store
+notebooks
+.*.swp
+.*.swo
+*.pyc
+*.so
+*.txt
+build
+/dist
+*.egg-info
+/env
+a.out
+boxes.jpg
+pandecode.dump
+vertex.bin
+recognize*
+.idea
+*.prof
+extra/disassemblers/applegpu
+extra/datasets/cifar-10-python.tar.gz
+extra/datasets/librispeech/
+extra/datasets/imagenet/
+extra/datasets/wiki/
+extra/datasets/kits19
+extra/datasets/kits19/
+extra/datasets/squad/
+extra/datasets/img_align_celeba*
+extra/datasets/open-images-v6-mlperf
+extra/datasets/kits/
+extra/datasets/COCO/
+extra/datasets/audio*
+extra/weights
+venv
+examples/**/net.*[js,json]
+examples/**/*.safetensors
+node_modules
+package.json
+package-lock.json
+temp
+*.csv
+.coverage
+coverage.xml
+htmlcov
+outputs_yolov8
+wandb
+model.safetensors
+quickstart.py
+.hypothesis
+weights
+*.lprof
+comgr_*
+*.pkl
+site/

+ 51 - 0
tinychat/.pre-commit-config.yaml

@@ -0,0 +1,51 @@
+repos:
+  - repo: local
+    hooks:
+      - id: whitespace
+        name: strip whitespace
+        entry: find tinygrad -type f -name "*.py" -exec sed -i '' 's/ *$//' '{}' ';'
+        language: system
+        always_run: true
+        pass_filenames: false
+      - id: ruff
+        name: ruff
+        entry: ruff check .
+        language: system
+        always_run: true
+        pass_filenames: false
+      - id: mypy
+        name: mypy
+        entry: mypy tinygrad/ --strict-equality
+        language: system
+        always_run: true
+        pass_filenames: false
+      - id: docs2
+        name: docs2
+        entry: python3 docs/abstractions2.py
+        language: system
+        always_run: true
+        pass_filenames: false
+      - id: devicetests
+        name: select GPU tests
+        entry: env GPU=1 PYTHONPATH="." pytest test/test_uops.py test/test_custom_function.py test/test_search.py
+        language: system
+        always_run: true
+        pass_filenames: false
+      - id: tests
+        name: subset of tests
+        entry: env PYTHONPATH="." python3 -m pytest -n=4 test/unit/ test/test_ops.py test/test_dtype.py test/test_schedule.py test/test_custom_function.py test/test_assign.py test/test_symbolic_shapetracker.py
+        language: system
+        always_run: true
+        pass_filenames: false
+      - id: example
+        name: multi device tests
+        entry: python3 test/external/external_test_example.py
+        language: system
+        always_run: true
+        pass_filenames: false
+      - id: pylint
+        name: pylint
+        entry: env PYTHONPATH="." python3 -m pylint tinygrad/
+        language: system
+        always_run: true
+        pass_filenames: false

+ 472 - 0
tinychat/.pylintrc

@@ -0,0 +1,472 @@
+[MASTER]
+
+# A comma-separated list of package or module names from where C extensions may
+# be loaded. Extensions are loading into the active Python interpreter and may
+# run arbitrary code
+extension-pkg-whitelist=scipy,cereal.messaging.messaging_pyx,PyQt5,av
+
+# Add files or directories to the blacklist. They should be base names, not
+# paths.
+ignore=CVS
+
+# Add files or directories matching the regex patterns to the blacklist. The
+# regex matches against base names, not paths.
+ignore-patterns=
+
+# Python code to execute, usually for sys.path manipulation such as
+# pygtk.require().
+#init-hook=
+
+# Use multiple processes to speed up Pylint.
+jobs=4
+
+# List of plugins (as comma separated values of python modules names) to load,
+# usually to register additional checkers.
+load-plugins=
+
+# Pickle collected data for later comparisons.
+persistent=yes
+
+# Specify a configuration file.
+#rcfile=
+
+# When enabled, pylint would attempt to guess common misconfiguration and emit
+# user-friendly hints instead of false-positive error messages
+suggestion-mode=yes
+
+# Allow loading of arbitrary C extensions. Extensions are imported into the
+# active Python interpreter and may run arbitrary code.
+unsafe-load-any-extension=no
+
+
+[MESSAGES CONTROL]
+
+# Only show warnings with the listed confidence levels. Leave empty to show
+# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED
+confidence=
+
+# Disable the message, report, category or checker with the given id(s). You
+# can either give multiple identifiers separated by comma (,) or put this
+# option multiple times (only on the command line, not in the configuration
+# file where it should appear only once).You can also use "--disable=all" to
+# disable everything first and then reenable specific checks. For example, if
+# you want to run only the similarities checker, you can use "--disable=all
+# --enable=similarities". If you want to run only the classes checker, but have
+# no Warning level messages displayed, use"--disable=all --enable=classes
+# --disable=W"
+disable=C,R,W0613,W0511,W0212,W0201,W0106,W0603,W0621,W0703,W1201,W1203,E1136,W1514,E1101,W0221,W0105,E0401
+# E1101 for function binding
+# W0221 for Function class
+# W0105 for comment strings
+# E0401 for missing imports
+
+# Enable the message, report, category or checker with the given id(s). You can
+# either give multiple identifier separated by comma (,) or put this option
+# multiple time (only on the command line, not in the configuration file where
+# it should appear only once). See also the "--disable" option for examples.
+enable=c-extension-no-member,use-a-generator, no-else-return
+
+
+[REPORTS]
+
+# Python expression which should return a note less than 10 (10 is the highest
+# note). You have access to the variables errors warning, statement which
+# respectively contain the number of errors / warnings messages and the total
+# number of statements analyzed. This is used by the global evaluation report
+# (RP0004).
+evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)
+
+# Template used to display messages. This is a python new-style format string
+# used to format the message information. See doc for all details
+#msg-template=
+
+# Set the output format. Available formats are text, parseable, colorized, json
+# and msvs (visual studio).You can also give a reporter class, eg
+# mypackage.mymodule.MyReporterClass.
+output-format=text
+
+# Tells whether to display a full report or only the messages
+reports=no
+
+# Activate the evaluation score.
+score=yes
+
+
+[REFACTORING]
+
+# Maximum number of nested blocks for function / method body
+max-nested-blocks=5
+
+# Complete name of functions that never returns. When checking for
+# inconsistent-return-statements if a never returning function is called then
+# it will be considered as an explicit return statement and no message will be
+# printed.
+never-returning-functions=optparse.Values,sys.exit
+
+
+[LOGGING]
+
+# Logging modules to check that the string format arguments are in logging
+# function parameter format
+logging-modules=logging
+
+
+[SPELLING]
+
+# Limits count of emitted suggestions for spelling mistakes
+max-spelling-suggestions=4
+
+# Spelling dictionary name. Available dictionaries: none. To make it working
+# install python-enchant package.
+spelling-dict=
+
+# List of comma separated words that should not be checked.
+spelling-ignore-words=
+
+# A path to a file that contains private dictionary; one word per line.
+spelling-private-dict-file=
+
+# Tells whether to store unknown words to indicated private dictionary in
+# --spelling-private-dict-file option instead of raising a message.
+spelling-store-unknown-words=no
+
+
+[MISCELLANEOUS]
+
+# List of note tags to take in consideration, separated by a comma.
+notes=FIXME,
+      XXX,
+      TODO
+
+
+[SIMILARITIES]
+
+# Ignore comments when computing similarities.
+ignore-comments=yes
+
+# Ignore docstrings when computing similarities.
+ignore-docstrings=yes
+
+# Ignore imports when computing similarities.
+ignore-imports=no
+
+# Minimum lines number of a similarity.
+min-similarity-lines=4
+
+
+[TYPECHECK]
+
+# List of decorators that produce context managers, such as
+# contextlib.contextmanager. Add to this list to register other decorators that
+# produce valid context managers.
+contextmanager-decorators=contextlib.contextmanager
+
+# List of members which are set dynamically and missed by pylint inference
+# system, and so shouldn't trigger E1101 when accessed. Python regular
+# expressions are accepted.
+generated-members=capnp.* cereal.* pygame.* zmq.* setproctitle.* smbus2.* usb1.* serial.* cv2.* ft4222.* carla.*
+
+# Tells whether missing members accessed in mixin class should be ignored. A
+# mixin class is detected if its name ends with "mixin" (case insensitive).
+ignore-mixin-members=yes
+
+# This flag controls whether pylint should warn about no-member and similar
+# checks whenever an opaque object is returned when inferring. The inference
+# can return multiple potential results while evaluating a Python object, but
+# some branches might not be evaluated, which results in partial inference. In
+# that case, it might be useful to still emit no-member and other checks for
+# the rest of the inferred objects.
+ignore-on-opaque-inference=yes
+
+# List of class names for which member attributes should not be checked (useful
+# for classes with dynamically set attributes). This supports the use of
+# qualified names.
+ignored-classes=optparse.Values,thread._local,_thread._local
+
+# List of module names for which member attributes should not be checked
+# (useful for modules/projects where namespaces are manipulated during runtime
+# and thus existing member attributes cannot be deduced by static analysis. It
+# supports qualified module names, as well as Unix pattern matching.
+ignored-modules=flask setproctitle usb1 flask.ext.socketio smbus2 usb1.*
+
+# Show a hint with possible names when a member name was not found. The aspect
+# of finding the hint is based on edit distance.
+missing-member-hint=yes
+
+# The minimum edit distance a name should have in order to be considered a
+# similar match for a missing member name.
+missing-member-hint-distance=1
+
+# The total number of similar names that should be taken in consideration when
+# showing a hint for a missing member.
+missing-member-max-choices=1
+
+
+[VARIABLES]
+
+# List of additional names supposed to be defined in builtins. Remember that
+# you should avoid to define new builtins when possible.
+additional-builtins=
+
+# Tells whether unused global variables should be treated as a violation.
+allow-global-unused-variables=yes
+
+# List of strings which can identify a callback function by name. A callback
+# name must start or end with one of those strings.
+callbacks=cb_,
+          _cb
+
+# A regular expression matching the name of dummy variables (i.e. expectedly
+# not used).
+dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_
+
+# Argument names that match this expression will be ignored. Default to name
+# with leading underscore
+ignored-argument-names=_.*|^ignored_|^unused_
+
+# Tells whether we should check for unused import in __init__ files.
+init-import=no
+
+# List of qualified module names which can have objects that can redefine
+# builtins.
+redefining-builtins-modules=six.moves,past.builtins,future.builtins
+
+
+[FORMAT]
+
+# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
+expected-line-ending-format=
+
+# Regexp for a line that is allowed to be longer than the limit.
+ignore-long-lines=^\s*(# )?<?https?://\S+>?$
+
+# Number of spaces of indent required inside a hanging  or continued line.
+indent-after-paren=4
+
+# String used as indentation unit. This is usually "    " (4 spaces) or "\t" (1
+# tab).
+indent-string='  '
+
+# Maximum number of characters on a single line.
+max-line-length=150
+
+# Maximum number of lines in a module
+max-module-lines=1000
+
+# Allow the body of a class to be on the same line as the declaration if body
+# contains single statement.
+single-line-class-stmt=no
+
+# Allow the body of an if to be on the same line as the test if there is no
+# else.
+single-line-if-stmt=no
+
+
+[BASIC]
+
+# Naming style matching correct argument names
+argument-naming-style=snake_case
+
+# Regular expression matching correct argument names. Overrides argument-
+# naming-style
+#argument-rgx=
+
+# Naming style matching correct attribute names
+attr-naming-style=snake_case
+
+# Regular expression matching correct attribute names. Overrides attr-naming-
+# style
+#attr-rgx=
+
+# Bad variable names which should always be refused, separated by a comma
+bad-names=foo,
+          bar,
+          baz,
+          toto,
+          tutu,
+          tata
+
+# Naming style matching correct class attribute names
+class-attribute-naming-style=any
+
+# Regular expression matching correct class attribute names. Overrides class-
+# attribute-naming-style
+#class-attribute-rgx=
+
+# Naming style matching correct class names
+class-naming-style=PascalCase
+
+# Regular expression matching correct class names. Overrides class-naming-style
+#class-rgx=
+
+# Naming style matching correct constant names
+const-naming-style=UPPER_CASE
+
+# Regular expression matching correct constant names. Overrides const-naming-
+# style
+#const-rgx=
+
+# Minimum line length for functions/classes that require docstrings, shorter
+# ones are exempt.
+docstring-min-length=-1
+
+# Naming style matching correct function names
+function-naming-style=snake_case
+
+# Regular expression matching correct function names. Overrides function-
+# naming-style
+#function-rgx=
+
+# Good variable names which should always be accepted, separated by a comma
+good-names=i,
+           j,
+           k,
+           ex,
+           Run,
+           _
+
+# Include a hint for the correct naming format with invalid-name
+include-naming-hint=no
+
+# Naming style matching correct inline iteration names
+inlinevar-naming-style=any
+
+# Regular expression matching correct inline iteration names. Overrides
+# inlinevar-naming-style
+#inlinevar-rgx=
+
+# Naming style matching correct method names
+method-naming-style=snake_case
+
+# Regular expression matching correct method names. Overrides method-naming-
+# style
+#method-rgx=
+
+# Naming style matching correct module names
+module-naming-style=snake_case
+
+# Regular expression matching correct module names. Overrides module-naming-
+# style
+#module-rgx=
+
+# Colon-delimited sets of names that determine each other's naming style when
+# the name regexes allow several styles.
+name-group=
+
+# Regular expression which should only match function or class names that do
+# not require a docstring.
+no-docstring-rgx=^_
+
+# List of decorators that produce properties, such as abc.abstractproperty. Add
+# to this list to register other decorators that produce valid properties.
+property-classes=abc.abstractproperty
+
+# Naming style matching correct variable names
+variable-naming-style=snake_case
+
+# Regular expression matching correct variable names. Overrides variable-
+# naming-style
+#variable-rgx=
+
+
+[DESIGN]
+
+# Maximum number of arguments for function / method
+max-args=5
+
+# Maximum number of attributes for a class (see R0902).
+max-attributes=7
+
+# Maximum number of boolean expressions in a if statement
+max-bool-expr=5
+
+# Maximum number of branch for function / method body
+max-branches=12
+
+# Maximum number of locals for function / method body
+max-locals=15
+
+# Maximum number of parents for a class (see R0901).
+max-parents=7
+
+# Maximum number of public methods for a class (see R0904).
+max-public-methods=20
+
+# Maximum number of return / yield for function / method body
+max-returns=6
+
+# Maximum number of statements in function / method body
+max-statements=50
+
+# Minimum number of public methods for a class (see R0903).
+min-public-methods=2
+
+
+[CLASSES]
+
+# List of method names used to declare (i.e. assign) instance attributes.
+defining-attr-methods=__init__,
+                      __new__,
+                      setUp
+
+# List of member names, which should be excluded from the protected access
+# warning.
+exclude-protected=_asdict,
+                  _fields,
+                  _replace,
+                  _source,
+                  _make
+
+# List of valid names for the first argument in a class method.
+valid-classmethod-first-arg=cls
+
+# List of valid names for the first argument in a metaclass class method.
+valid-metaclass-classmethod-first-arg=mcs
+
+
+[IMPORTS]
+
+# Allow wildcard imports from modules that define __all__.
+allow-wildcard-with-all=no
+
+# Analyse import fallback blocks. This can be used to support both Python 2 and
+# 3 compatible code, which means that the block might have code that exists
+# only in one or another interpreter, leading to false positives when analysed.
+analyse-fallback-blocks=no
+
+# Deprecated modules which should not be used, separated by a comma
+deprecated-modules=regsub,
+                   TERMIOS,
+                   Bastion,
+                   rexec
+
+# Create a graph of external dependencies in the given file (report RP0402 must
+# not be disabled)
+ext-import-graph=
+
+# Create a graph of every (i.e. internal and external) dependencies in the
+# given file (report RP0402 must not be disabled)
+import-graph=
+
+# Create a graph of internal dependencies in the given file (report RP0402 must
+# not be disabled)
+int-import-graph=
+
+# Force import order to recognize a module as part of the standard
+# compatibility libraries.
+known-standard-library=
+
+# Force import order to recognize a module as part of a third party library.
+known-third-party=enchant
+
+[STRING]
+
+# This flag controls whether the implicit-str-concat should generate a warning
+# on implicit string concatenation in sequences defined over several lines.
+check-str-concat-over-line-jumps=yes
+
+[EXCEPTIONS]
+
+# Exceptions that will emit a warning when being caught. Defaults to
+# "Exception"
+overgeneral-exceptions=builtins.Exception

+ 7 - 0
tinychat/LICENSE

@@ -0,0 +1,7 @@
+Copyright (c) 2024, the tiny corp
+
+Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

+ 178 - 0
tinychat/README.md

@@ -0,0 +1,178 @@
+<div align="center">
+
+<picture>
+  <source media="(prefers-color-scheme: light)" srcset="/docs/logo_tiny_light.svg">
+  <img alt="tiny corp logo" src="/docs/logo_tiny_dark.svg" width="50%" height="50%">
+</picture>
+
+tinygrad: For something between [PyTorch](https://github.com/pytorch/pytorch) and [karpathy/micrograd](https://github.com/karpathy/micrograd). Maintained by [tiny corp](https://tinygrad.org).
+
+<h3>
+
+[Homepage](https://github.com/tinygrad/tinygrad) | [Documentation](https://docs.tinygrad.org/) | [Discord](https://discord.gg/ZjZadyC7PK)
+
+</h3>
+
+[![GitHub Repo stars](https://img.shields.io/github/stars/tinygrad/tinygrad)](https://github.com/tinygrad/tinygrad/stargazers)
+[![Unit Tests](https://github.com/tinygrad/tinygrad/actions/workflows/test.yml/badge.svg)](https://github.com/tinygrad/tinygrad/actions/workflows/test.yml)
+[![Discord](https://img.shields.io/discord/1068976834382925865)](https://discord.gg/ZjZadyC7PK)
+
+</div>
+
+---
+
+This may not be the best deep learning framework, but it is a deep learning framework.
+
+Due to its extreme simplicity, it aims to be the easiest framework to add new accelerators to, with support for both inference and training. If XLA is CISC, tinygrad is RISC.
+
+tinygrad is still alpha software, but we [raised some money](https://geohot.github.io/blog/jekyll/update/2023/05/24/the-tiny-corp-raised-5M.html) to make it good. Someday, we will tape out chips.
+
+## Features
+
+### LLaMA and Stable Diffusion
+
+tinygrad can run [LLaMA](/docs/showcase.md#llama) and [Stable Diffusion](/docs/showcase.md#stable-diffusion)!
+
+### Laziness
+
+Try a matmul. See how, despite the style, it is fused into one kernel with the power of laziness.
+
+```sh
+DEBUG=3 python3 -c "from tinygrad import Tensor;
+N = 1024; a, b = Tensor.rand(N, N), Tensor.rand(N, N);
+c = (a.reshape(N, 1, N) * b.T.reshape(1, N, N)).sum(axis=2);
+print((c.numpy() - (a.numpy() @ b.numpy())).mean())"
+```
+
+And we can change `DEBUG` to `4` to see the generated code.
+
+### Neural networks
+
+As it turns out, 90% of what you need for neural networks is a decent autograd/tensor library.
+Throw in an optimizer, a data loader, and some compute, and you have all you need.
+
+```python
+from tinygrad import Tensor, nn
+
+class LinearNet:
+  def __init__(self):
+    self.l1 = Tensor.kaiming_uniform(784, 128)
+    self.l2 = Tensor.kaiming_uniform(128, 10)
+  def __call__(self, x:Tensor) -> Tensor:
+    return x.flatten(1).dot(self.l1).relu().dot(self.l2)
+
+model = LinearNet()
+optim = nn.optim.Adam([model.l1, model.l2], lr=0.001)
+
+x, y = Tensor.rand(4, 1, 28, 28), Tensor([2,4,3,7])  # replace with real mnist dataloader
+
+with Tensor.train():
+  for i in range(10):
+    optim.zero_grad()
+    loss = model(x).sparse_categorical_crossentropy(y).backward()
+    optim.step()
+    print(i, loss.item())
+```
+
+See [examples/beautiful_mnist.py](examples/beautiful_mnist.py) for the full version that gets 98% in ~5 seconds.
+
+## Accelerators
+
+tinygrad already supports numerous accelerators, including:
+
+- [x] [GPU (OpenCL)](tinygrad/runtime/ops_gpu.py)
+- [x] [CLANG (C Code)](tinygrad/runtime/ops_clang.py)
+- [x] [LLVM](tinygrad/runtime/ops_llvm.py)
+- [x] [METAL](tinygrad/runtime/ops_metal.py)
+- [x] [CUDA](tinygrad/runtime/ops_cuda.py)
+- [x] [AMD](tinygrad/runtime/ops_amd.py)
+- [x] [NV](tinygrad/runtime/ops_nv.py)
+
+And it is easy to add more! Your accelerator of choice only needs to support a total of ~25 low level ops.
+
+## Installation
+
+The current recommended way to install tinygrad is from source.
+
+### From source
+
+```sh
+git clone https://github.com/tinygrad/tinygrad.git
+cd tinygrad
+python3 -m pip install -e .
+```
+
+### Direct (master)
+
+```sh
+python3 -m pip install git+https://github.com/tinygrad/tinygrad.git
+```
+
+## Documentation
+
+Documentation along with a quick start guide can be found on the [docs website](https://docs.tinygrad.org/) built from the [docs/](/docs) directory.
+
+### Quick example comparing to PyTorch
+
+```python
+from tinygrad import Tensor
+
+x = Tensor.eye(3, requires_grad=True)
+y = Tensor([[2.0,0,-2.0]], requires_grad=True)
+z = y.matmul(x).sum()
+z.backward()
+
+print(x.grad.numpy())  # dz/dx
+print(y.grad.numpy())  # dz/dy
+```
+
+The same thing but in PyTorch:
+```python
+import torch
+
+x = torch.eye(3, requires_grad=True)
+y = torch.tensor([[2.0,0,-2.0]], requires_grad=True)
+z = y.matmul(x).sum()
+z.backward()
+
+print(x.grad.numpy())  # dz/dx
+print(y.grad.numpy())  # dz/dy
+```
+
+## Contributing
+
+There has been a lot of interest in tinygrad lately. Following these guidelines will help your PR get accepted.
+
+We'll start with what will get your PR closed with a pointer to this section:
+
+- No code golf! While low line count is a guiding light of this project, anything that remotely looks like code golf will be closed. The true goal is reducing complexity and increasing readability, and deleting `\n`s does nothing to help with that.
+- All docs and whitespace changes will be closed unless you are a well-known contributor. The people writing the docs should be those who know the codebase the absolute best. People who have not demonstrated that shouldn't be messing with docs. Whitespace changes are both useless *and* carry a risk of introducing bugs.
+- Anything you claim is a "speedup" must be benchmarked. In general, the goal is simplicity, so even if your PR makes things marginally faster, you have to consider the tradeoff with maintainability and readability.
+- In general, the code outside the core `tinygrad/` folder is not well tested, so unless the current code there is broken, you shouldn't be changing it.
+- If your PR looks "complex", is a big diff, or adds lots of lines, it won't be reviewed or merged. Consider breaking it up into smaller PRs that are individually clear wins. A common pattern I see is prerequisite refactors before adding new functionality. If you can (cleanly) refactor to the point that the feature is a 3 line change, this is great, and something easy for us to review.
+
+Now, what we want:
+
+- Bug fixes (with a regression test) are great! This library isn't 1.0 yet, so if you stumble upon a bug, fix it, write a test, and submit a PR, this is valuable work.
+- Solving bounties! tinygrad [offers cash bounties](https://docs.google.com/spreadsheets/d/1WKHbT-7KOgjEawq5h5Ic1qUWzpfAzuD_J06N1JwOCGs/edit?usp=sharing) for certain improvements to the library. All new code should be high quality and well tested.
+- Features. However, if you are adding a feature, consider the line tradeoff. If it's 3 lines, the bar of usefulness it has to meet is lower than for something that's 30 or 300 lines. All features must have regression tests. In general, with no other constraints, your feature's API should match torch or numpy.
+- Refactors that are clear wins. In general, if your refactor isn't a clear win it will be closed. But some refactors are amazing! Think about readability in a deep core sense. A whitespace change or moving a few functions around is useless, but if you realize that two 100 line functions can actually use the same 110 line function with arguments while also improving readability, this is a big win. Refactors should pass [process replay](#process-replay-tests).
+- Tests/fuzzers. If you can add tests that are non-brittle, they are welcome. We have some fuzzers in here too, and there's a plethora of bugs that can be found with them and by improving them. Finding bugs, even writing broken tests (that should pass) with `@unittest.expectedFailure`, is great. This is how we make progress.
+- Dead code removal from core `tinygrad/` folder. We don't care about the code in extra, but removing dead code from the core library is great. Less for new people to read and be confused by.
+
+### Running tests
+
+You should install the pre-commit hooks with `pre-commit install`. This will run the linter, mypy, and a subset of the tests on every commit.
+
+For more examples on how to run the full test suite please refer to the [CI workflow](.github/workflows/test.yml).
+
+Some examples of running tests locally:
+```sh
+python3 -m pip install -e '.[testing]'  # install extra deps for testing
+python3 test/test_ops.py                # just the ops tests
+python3 -m pytest test/                 # whole test suite
+```
+
+#### Process replay tests
+
+[Process replay](https://github.com/tinygrad/tinygrad/blob/master/test/external/process_replay/process_replay.py) compares your PR's generated kernels against master. If your PR is a refactor or speedup without any expected behavior change, it should include [run_process_replay] in the PR title ([example](https://github.com/tinygrad/tinygrad/pull/4995)). Note that you should keep your branch up-to-date with master.

+ 289 - 0
tinychat/autogen_stubs.sh

@@ -0,0 +1,289 @@
+#!/bin/bash -e
+
+# setup instructions for clang2py
+if [[ ! $(clang2py -V) ]]; then
+  pushd .
+  cd /tmp
+  sudo apt-get install -y --no-install-recommends clang
+  pip install --upgrade pip setuptools
+  pip install clang==14.0.6
+  git clone https://github.com/geohot/ctypeslib.git
+  cd ctypeslib
+  pip install --user .
+  clang2py -V
+  popd
+fi
+
+BASE=tinygrad/runtime/autogen/
+
+fixup() {
+  sed -i '1s/^/# mypy: ignore-errors\n/' $1
+  sed -i 's/ *$//' $1
+  grep FIXME_STUB $1 || true
+}
+
+patch_dlopen() {
+  path=$1; shift
+  name=$1; shift
+  cat <<EOF | sed -i "/import ctypes.*/r /dev/stdin" $path
+PATHS_TO_TRY = [
+$(for p in "$@"; do echo "  $p,"; done)
+]
+def _try_dlopen_$name():
+  library = ctypes.util.find_library("$name")
+  if library: return ctypes.CDLL(library)
+  for candidate in PATHS_TO_TRY:
+    try: return ctypes.CDLL(candidate)
+    except OSError: pass
+  raise RuntimeError("library $name not found")
+EOF
+}
+
+process_cdefines() {
+  local input_file="$1"
+  local output_file="$2"
+
+  sed -E '
+    # Remove single-line comments
+    s/[[:space:]]*\/\*.*\*\///g
+
+    # Remove multi-line comments
+    /\/\*/,/\*\//d
+
+    /.*DT_MIPS_NUM.*/d
+
+    # Remove lines ending with backslash (multi-line macros)
+    /\\$/d
+
+    # Convert C integer literals (remove U suffix)
+    s/\b([0-9]+)U\b/\1/g
+
+    # Convert C types to Python ctypes
+    s/\bunsigned char\b/ctypes.c_ubyte/g
+    s/\bsigned char\b/ctypes.c_byte/g
+    s/\bunsigned short\b/ctypes.c_ushort/g
+    s/\bshort\b/ctypes.c_short/g
+    s/\bunsigned int\b/ctypes.c_uint/g
+    s/\bint\b/ctypes.c_int/g
+    s/\bunsigned long\b/ctypes.c_ulong/g
+    s/\blong\b/ctypes.c_long/g
+    s/\bfloat\b/ctypes.c_float/g
+    s/\bdouble\b/ctypes.c_double/g
+
+    # Function-like macros with parameters
+    /^#define[[:space:]]+([[:alnum:]_]+)[[:space:]]*\(([^)]*)\)[[:space:]]+(.+)/ {
+      s//def \1(\2): return \3/
+      p
+      d
+    }
+
+    # Simple #define statements (including those with parentheses)
+    /^#define[[:space:]]+([[:alnum:]_]+)[[:space:]]+(.+)/ {
+      s//\1 = \2/
+      p
+      d
+    }
+
+    # Drop all other lines
+    d
+  ' "$input_file" >> "$output_file"
+}
+
+generate_opencl() {
+  clang2py /usr/include/CL/cl.h -o $BASE/opencl.py -l /usr/lib/x86_64-linux-gnu/libOpenCL.so.1 -k cdefstum
+  fixup $BASE/opencl.py
+  # hot patches
+  sed -i "s\import ctypes\import ctypes, ctypes.util\g" $BASE/opencl.py
+  sed -i "s\ctypes.CDLL('/usr/lib/x86_64-linux-gnu/libOpenCL.so.1')\ctypes.CDLL(ctypes.util.find_library('OpenCL'))\g" $BASE/opencl.py
+  python3 -c "import tinygrad.runtime.autogen.opencl"
+}
+
+generate_hip() {
+  clang2py /opt/rocm/include/hip/hip_ext.h /opt/rocm/include/hip/hiprtc.h \
+  /opt/rocm/include/hip/hip_runtime_api.h /opt/rocm/include/hip/driver_types.h \
+  --clang-args="-D__HIP_PLATFORM_AMD__ -I/opt/rocm/include -x c++" -o $BASE/hip.py -l /opt/rocm/lib/libamdhip64.so
+  echo "hipDeviceProp_t = hipDeviceProp_tR0600" >> $BASE/hip.py
+  echo "hipGetDeviceProperties = hipGetDevicePropertiesR0600" >> $BASE/hip.py
+  fixup $BASE/hip.py
+  # we can trust HIP is always at /opt/rocm/lib
+  #sed -i "s\import ctypes\import ctypes, ctypes.util\g" $BASE/hip.py
+  #sed -i "s\ctypes.CDLL('/opt/rocm/lib/libhiprtc.so')\ctypes.CDLL(ctypes.util.find_library('hiprtc'))\g" $BASE/hip.py
+  #sed -i "s\ctypes.CDLL('/opt/rocm/lib/libamdhip64.so')\ctypes.CDLL(ctypes.util.find_library('amdhip64'))\g" $BASE/hip.py
+  sed -i "s\import ctypes\import ctypes, os\g" $BASE/hip.py
+  sed -i "s\'/opt/rocm/\os.getenv('ROCM_PATH', '/opt/rocm/')+'/\g" $BASE/hip.py
+  python3 -c "import tinygrad.runtime.autogen.hip"
+}
+
+generate_comgr() {
+  clang2py /opt/rocm/include/amd_comgr/amd_comgr.h \
+  --clang-args="-D__HIP_PLATFORM_AMD__ -I/opt/rocm/include -x c++" -o $BASE/comgr.py -l /opt/rocm/lib/libamd_comgr.so
+  fixup $BASE/comgr.py
+  sed -i "s\import ctypes\import ctypes, ctypes.util, os\g" $BASE/comgr.py
+  patch_dlopen $BASE/comgr.py amd_comgr "'/opt/rocm/lib/libamd_comgr.so'" "os.getenv('ROCM_PATH', '')+'/lib/libamd_comgr.so'"
+  sed -i "s\ctypes.CDLL('/opt/rocm/lib/libamd_comgr.so')\_try_dlopen_amd_comgr()\g" $BASE/comgr.py
+  python3 -c "import tinygrad.runtime.autogen.comgr"
+}
+
+generate_kfd() {
+  clang2py /usr/include/linux/kfd_ioctl.h -o $BASE/kfd.py -k cdefstum
+  fixup $BASE/kfd.py
+  sed -i "s\import ctypes\import ctypes, os\g" $BASE/kfd.py
+  python3 -c "import tinygrad.runtime.autogen.kfd"
+}
+
+generate_cuda() {
+  clang2py /usr/include/cuda.h -o $BASE/cuda.py -l /usr/lib/x86_64-linux-gnu/libcuda.so
+  sed -i "s\import ctypes\import ctypes, ctypes.util\g" $BASE/cuda.py
+  sed -i "s\ctypes.CDLL('/usr/lib/x86_64-linux-gnu/libcuda.so')\ctypes.CDLL(ctypes.util.find_library('cuda'))\g" $BASE/cuda.py
+  fixup $BASE/cuda.py
+  python3 -c "import tinygrad.runtime.autogen.cuda"
+}
+
+generate_nvrtc() {
+  clang2py /usr/local/cuda/include/nvrtc.h /usr/local/cuda/include/nvJitLink.h -o $BASE/nvrtc.py -l /usr/local/cuda/lib64/libnvrtc.so -l /usr/local/cuda/lib64/libnvJitLink.so
+  sed -i "s\import ctypes\import ctypes, ctypes.util\g" $BASE/nvrtc.py
+  sed -i "s\ctypes.CDLL('/usr/local/cuda/lib64/libnvrtc.so')\ctypes.CDLL(ctypes.util.find_library('nvrtc'))\g" $BASE/nvrtc.py
+  sed -i "s\ctypes.CDLL('/usr/local/cuda/lib64/libnvJitLink.so')\ctypes.CDLL(ctypes.util.find_library('nvJitLink'))\g" $BASE/nvrtc.py
+  fixup $BASE/nvrtc.py
+  python3 -c "import tinygrad.runtime.autogen.nvrtc"
+}
+
+generate_nv() {
+  NVKERN_COMMIT_HASH=d6b75a34094b0f56c2ccadf14e5d0bd515ed1ab6
+  NVKERN_SRC=/tmp/open-gpu-kernel-modules-$NVKERN_COMMIT_HASH
+  if [ ! -d "$NVKERN_SRC" ]; then
+    git clone https://github.com/tinygrad/open-gpu-kernel-modules $NVKERN_SRC
+    pushd .
+    cd $NVKERN_SRC
+    git reset --hard $NVKERN_COMMIT_HASH
+    popd
+  fi
+
+  clang2py \
+    extra/nv_gpu_driver/clc6c0qmd.h \
+    $NVKERN_SRC/src/common/sdk/nvidia/inc/class/cl0080.h \
+    $NVKERN_SRC/src/common/sdk/nvidia/inc/class/cl2080_notification.h \
+    $NVKERN_SRC/src/common/sdk/nvidia/inc/class/clc56f.h \
+    $NVKERN_SRC/src/common/sdk/nvidia/inc/class/clc56f.h \
+    $NVKERN_SRC/src/common/sdk/nvidia/inc/class/clc56f.h \
+    $NVKERN_SRC/src/nvidia/generated/g_allclasses.h \
+    $NVKERN_SRC/src/common/sdk/nvidia/inc/class/clc6c0.h \
+    $NVKERN_SRC/kernel-open/nvidia-uvm/clc6b5.h \
+    $NVKERN_SRC/kernel-open/nvidia-uvm/uvm_ioctl.h \
+    $NVKERN_SRC/kernel-open/nvidia-uvm/uvm_linux_ioctl.h \
+    $NVKERN_SRC/src/nvidia/arch/nvalloc/unix/include/nv_escape.h \
+    $NVKERN_SRC/src/nvidia/arch/nvalloc/unix/include/nv-ioctl.h \
+    $NVKERN_SRC/src/nvidia/arch/nvalloc/unix/include/nv-ioctl-numbers.h \
+    $NVKERN_SRC/src/nvidia/arch/nvalloc/unix/include/nv-ioctl-numa.h \
+    $NVKERN_SRC/src/nvidia/arch/nvalloc/unix/include/nv-unix-nvos-params-wrappers.h \
+    $NVKERN_SRC/src/common/sdk/nvidia/inc/alloc/alloc_channel.h \
+    $NVKERN_SRC/src/common/sdk/nvidia/inc/nvos.h \
+    $NVKERN_SRC/src/common/sdk/nvidia/inc/ctrl/ctrl0000/*.h \
+    $NVKERN_SRC/src/common/sdk/nvidia/inc/ctrl/ctrl0080/*.h \
+    $NVKERN_SRC/src/common/sdk/nvidia/inc/ctrl/ctrl2080/*.h \
+    $NVKERN_SRC/src/common/sdk/nvidia/inc/ctrl/ctrl83de/*.h \
+    $NVKERN_SRC/src/common/sdk/nvidia/inc/ctrl/ctrlc36f.h \
+    $NVKERN_SRC/src/common/sdk/nvidia/inc/ctrl/ctrlcb33.h \
+    $NVKERN_SRC/src/common/sdk/nvidia/inc/ctrl/ctrla06c.h \
+    --clang-args="-include $NVKERN_SRC/src/common/sdk/nvidia/inc/nvtypes.h -I$NVKERN_SRC/src/common/inc -I$NVKERN_SRC/kernel-open/nvidia-uvm -I$NVKERN_SRC/kernel-open/common/inc -I$NVKERN_SRC/src/common/sdk/nvidia/inc -I$NVKERN_SRC/src/nvidia/arch/nvalloc/unix/include -I$NVKERN_SRC/src/common/sdk/nvidia/inc/ctrl" \
+    -o $BASE/nv_gpu.py -k cdefstum
+  fixup $BASE/nv_gpu.py
+  sed -i "s\(0000000001)\1\g" $BASE/nv_gpu.py
+  sed -i "s\import ctypes\import ctypes, os\g" $BASE/nv_gpu.py
+  sed -i 's/#\?\s\([A-Za-z0-9_]\+\) = MW ( \([0-9]\+\) : \([0-9]\+\) )/\1 = (\2 , \3)/' $BASE/nv_gpu.py # NVC6C0_QMDV03_00 processing
+  sed -i 's/#\sdef NVC6C0_QMD\([A-Za-z0-9_()]\+\):/def NVC6C0_QMD\1:/' $BASE/nv_gpu.py
+  sed -i 's/#\s*return MW(\([0-9i()*+]\+\):\([0-9i()*+]\+\))/    return (\1 , \2)/' $BASE/nv_gpu.py
+  sed -i 's/#\?\s*\(.*\)\s*=\s*\(NV\)\?BIT\(32\)\?\s*(\s*\([0-9]\+\)\s*)/\1 = (1 << \4)/' $BASE/nv_gpu.py # name = BIT(x) -> name = (1 << x)
+  sed -i "s/UVM_\([A-Za-z0-9_]\+\) = \['i', '(', '\([0-9]\+\)', ')'\]/UVM_\1 = \2/" $BASE/nv_gpu.py # UVM_name = ['i', '(', '<num>', ')'] -> UVM_name = <num>
+
+  # Parse status codes
+  sed -n '1i\
+nv_status_codes = {}
+/^NV_STATUS_CODE/ { s/^NV_STATUS_CODE(\([^,]*\), *\([^,]*\), *"\([^"]*\)") *.*$/\1 = \2\nnv_status_codes[\1] = "\3"/; p }' $NVKERN_SRC/src/common/sdk/nvidia/inc/nvstatuscodes.h >> $BASE/nv_gpu.py
+
+  python3 -c "import tinygrad.runtime.autogen.nv_gpu"
+}
+
+generate_amd() {
+  # clang2py is broken when passing -x c++ for the previous headers
+  clang2py extra/hip_gpu_driver/sdma_registers.h \
+    --clang-args="-I/opt/rocm/include -x c++" \
+    -o $BASE/amd_gpu.py
+
+  sed 's/^\(.*\)\(\s*\/\*\)\(.*\)$/\1 #\2\3/; s/^\(\s*\*\)\(.*\)$/#\1\2/' extra/hip_gpu_driver/nvd.h >> $BASE/amd_gpu.py # comments
+  sed 's/^\(.*\)\(\s*\/\*\)\(.*\)$/\1 #\2\3/; s/^\(\s*\*\)\(.*\)$/#\1\2/' extra/hip_gpu_driver/sdma_v6_0_0_pkt_open.h >> $BASE/amd_gpu.py # comments
+  sed -i 's/#\s*define\s*\([^ \t]*\)(\([^)]*\))\s*\(.*\)/def \1(\2): return \3/' $BASE/amd_gpu.py # #define name(x) (smth) -> def name(x): return (smth)
+  sed -i '/#\s*define\s\+\([^ \t]\+\)\s\+\([^ ]\+\)/s//\1 = \2/' $BASE/amd_gpu.py # #define name val -> name = val
+
+  sed -e '/^reg/s/^\(reg[^ ]*\) [^ ]* \([^ ]*\) .*/\1 = \2/' \
+    -e '/^ix/s/^\(ix[^ ]*\) [^ ]* \([^ ]*\) .*/\1 = \2/' \
+    -e '/^[ \t]/d' \
+    extra/hip_gpu_driver/gc_11_0_0.reg >> $BASE/amd_gpu.py
+
+  fixup $BASE/amd_gpu.py
+  sed -i "s\import ctypes\import ctypes, os\g" $BASE/amd_gpu.py
+  python3 -c "import tinygrad.runtime.autogen.amd_gpu"
+}
+
+generate_hsa() {
+  clang2py \
+    /opt/rocm/include/hsa/hsa.h \
+    /opt/rocm/include/hsa/hsa_ext_amd.h \
+    /opt/rocm/include/hsa/amd_hsa_signal.h \
+    /opt/rocm/include/hsa/amd_hsa_queue.h \
+    /opt/rocm/include/hsa/amd_hsa_kernel_code.h \
+    /opt/rocm/include/hsa/hsa_ext_finalize.h /opt/rocm/include/hsa/hsa_ext_image.h \
+    /opt/rocm/include/hsa/hsa_ven_amd_aqlprofile.h \
+    --clang-args="-I/opt/rocm/include" \
+    -o $BASE/hsa.py -l /opt/rocm/lib/libhsa-runtime64.so
+
+  fixup $BASE/hsa.py
+  sed -i "s\import ctypes\import ctypes, ctypes.util, os\g" $BASE/hsa.py
+  sed -i "s\ctypes.CDLL('/opt/rocm/lib/libhsa-runtime64.so')\ctypes.CDLL(os.getenv('ROCM_PATH')+'/lib/libhsa-runtime64.so' if os.getenv('ROCM_PATH') else ctypes.util.find_library('hsa-runtime64'))\g" $BASE/hsa.py
+  python3 -c "import tinygrad.runtime.autogen.hsa"
+}
+
+generate_io_uring() {
+  clang2py \
+    /usr/include/liburing.h \
+    /usr/include/linux/io_uring.h \
+    -o $BASE/io_uring.py
+
+  # clang2py can't parse defines
+  sed -r '/^#define __NR_io_uring/ s/^#define __(NR_io_uring[^ ]+) (.*)$/\1 = \2/; t; d' /usr/include/asm-generic/unistd.h >> $BASE/io_uring.py # io_uring syscalls numbers
+  sed -r '/^#define\s+([^ \t]+)\s+([^ \t]+)/ s/^#define\s+([^ \t]+)\s*([^/]*).*$/\1 = \2/; s/1U/1/g; s/0ULL/0/g; t; d' /usr/include/linux/io_uring.h >> $BASE/io_uring.py # #define name (val) -> name = val
+
+  fixup $BASE/io_uring.py
+}
+
+generate_libc() {
+  clang2py \
+    $(dpkg -L libc6-dev | grep sys/mman.h) \
+    $(dpkg -L libc6-dev | grep sys/syscall.h) \
+    /usr/include/elf.h \
+    /usr/include/unistd.h \
+    -o $BASE/libc.py
+
+  process_cdefines "/usr/include/elf.h" "$BASE/libc.py"
+
+  sed -i "s\import ctypes\import ctypes, ctypes.util, os\g" $BASE/libc.py
+  sed -i "s\FIXME_STUB\libc\g" $BASE/libc.py
+  sed -i "s\FunctionFactoryStub()\ctypes.CDLL(ctypes.util.find_library('c'))\g" $BASE/libc.py
+
+  fixup $BASE/libc.py
+}
+
+if [ "$1" == "opencl" ]; then generate_opencl
+elif [ "$1" == "hip" ]; then generate_hip
+elif [ "$1" == "comgr" ]; then generate_comgr
+elif [ "$1" == "cuda" ]; then generate_cuda
+elif [ "$1" == "nvrtc" ]; then generate_nvrtc
+elif [ "$1" == "hsa" ]; then generate_hsa
+elif [ "$1" == "kfd" ]; then generate_kfd
+elif [ "$1" == "nv" ]; then generate_nv
+elif [ "$1" == "amd" ]; then generate_amd
+elif [ "$1" == "io_uring" ]; then generate_io_uring
+elif [ "$1" == "libc" ]; then generate_libc
+elif [ "$1" == "all" ]; then generate_opencl; generate_hip; generate_comgr; generate_cuda; generate_nvrtc; generate_hsa; generate_kfd; generate_nv; generate_amd; generate_io_uring; generate_libc
+else echo "usage: $0 <type>"
+fi

+ 1 - 0
tinychat/docs/CNAME

@@ -0,0 +1 @@
+docs.tinygrad.org

+ 118 - 0
tinychat/docs/abstractions2.py

@@ -0,0 +1,118 @@
+# tinygrad is a tensor library, and as a tensor library it has multiple parts
+# 1. a "runtime". this allows buffer management, compilation, and running programs
+# 2. a "Device" that uses the runtime but specifies compute in an abstract way for all
+# 3. a "LazyBuffer" that fuses the compute into kernels, using memory only when needed
+# 4. a "Tensor" that provides an easy to use frontend with autograd ".backward()"
+
+
+print("******** first, the runtime ***********")
+
+from tinygrad.runtime.ops_clang import ClangProgram, ClangCompiler, MallocAllocator
+
+# allocate some buffers
+out = MallocAllocator.alloc(4)
+a = MallocAllocator.alloc(4)
+b = MallocAllocator.alloc(4)
+
+# load in some values (little endian)
+MallocAllocator.copyin(a, bytearray([2,0,0,0]))
+MallocAllocator.copyin(b, bytearray([3,0,0,0]))
+
+# compile a program to a binary
+lib = ClangCompiler().compile("void add(int *out, int *a, int *b) { out[0] = a[0] + b[0]; }")
+
+# create a runtime for the program (ctypes.CDLL)
+fxn = ClangProgram("add", lib)
+
+# run the program
+fxn(out, a, b)
+
+# check the data out
+print(val := MallocAllocator.as_buffer(out).cast("I").tolist()[0])
+assert val == 5
+
+
+print("******** second, the Device ***********")
+
+DEVICE = "CLANG"   # NOTE: you can change this!
+
+import struct
+from tinygrad.dtype import dtypes
+from tinygrad.device import Buffer, Device
+from tinygrad.ops import LazyOp, BufferOps, MemBuffer, BinaryOps, MetaOps
+from tinygrad.shape.shapetracker import ShapeTracker
+
+# allocate some buffers + load in values
+out = Buffer(DEVICE, 1, dtypes.int32).allocate()
+a = Buffer(DEVICE, 1, dtypes.int32).allocate().copyin(memoryview(bytearray(struct.pack("I", 2))))
+b = Buffer(DEVICE, 1, dtypes.int32).allocate().copyin(memoryview(bytearray(struct.pack("I", 3))))
+# NOTE: a._buf is the same as the return from MallocAllocator.alloc
+
+# describe the computation
+ld_1 = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.int32, ShapeTracker.from_shape((1,))))
+ld_2 = LazyOp(BufferOps.LOAD, (), MemBuffer(2, dtypes.int32, ShapeTracker.from_shape((1,))))
+alu = LazyOp(BinaryOps.ADD, (ld_1, ld_2))
+st_0 = LazyOp(BufferOps.STORE, (alu,), MemBuffer(0, dtypes.int32, ShapeTracker.from_shape((1,))))
+sink = LazyOp(MetaOps.KERNEL, (st_0,))
+
+# convert the computation to a "linearized" format (print the format)
+from tinygrad.engine.realize import get_kernel, CompiledRunner
+lin = get_kernel(Device[DEVICE].renderer, sink).linearize()
+for u in lin.uops: print(u)
+
+# compile a program (and print the source)
+fxn = CompiledRunner(lin.to_program())
+print(fxn.p.src)
+# NOTE: fxn.clprg is the ClangProgram
+
+# run the program
+fxn.exec([out, a, b])
+
+# check the data out
+assert out.as_buffer().cast('I')[0] == 5
+
+
+print("******** third, the LazyBuffer ***********")
+
+from tinygrad.lazy import LazyBuffer
+from tinygrad.engine.realize import run_schedule
+from tinygrad.engine.schedule import create_schedule
+
+# allocate some values + load in values
+a = LazyBuffer.metaop(MetaOps.EMPTY, (1,), dtypes.int32, DEVICE)
+b = LazyBuffer.metaop(MetaOps.EMPTY, (1,), dtypes.int32, DEVICE)
+a.buffer.allocate().copyin(memoryview(bytearray(struct.pack("I", 2))))
+b.buffer.allocate().copyin(memoryview(bytearray(struct.pack("I", 3))))
+del a.srcs
+del b.srcs
+
+# describe the computation
+out = a.e(BinaryOps.ADD, b)
+
+# schedule the computation as a list of kernels
+sched = create_schedule([out])
+for si in sched: print(si.ast.op)  # NOTE: the first two convert it to CLANG
+
+# DEBUGGING: print the compute ast as a tree
+from tinygrad.engine.graph import print_tree
+print_tree(sched[-1].ast)
+# NOTE: sched[-1].ast is the same as st_0 above
+
+# run that schedule
+run_schedule(sched)
+
+# check the data out
+assert out.realized.as_buffer().cast('I')[0] == 5
+
+
+print("******** fourth, the Tensor ***********")
+
+from tinygrad import Tensor
+
+a = Tensor([2], dtype=dtypes.int32, device=DEVICE)
+b = Tensor([3], dtype=dtypes.int32, device=DEVICE)
+out = a + b
+
+# check the data out
+print(val:=out.item())
+assert val == 5

+ 62 - 0
tinychat/docs/abstractions3.py

@@ -0,0 +1,62 @@
+# abstractions2 goes from back to front, here we will go from front to back
+from typing import List
+from tqdm import tqdm
+from tinygrad.helpers import DEBUG
+
+# *****
+# 0. Load mnist on the device
+
+from tinygrad.nn.datasets import mnist
+X_train, Y_train, _, _ = mnist()
+X_train = X_train.float()
+X_train -= X_train.mean()
+
+# *****
+# 1. Define an MNIST model.
+
+from tinygrad import Tensor
+
+l1 = Tensor.kaiming_uniform(128, 784)
+l2 = Tensor.kaiming_uniform(10, 128)
+def model(x): return x.flatten(1).dot(l1.T).relu().dot(l2.T)
+l1n, l2n = l1.numpy(), l2.numpy()
+
+# *****
+# 2. Choose a batch for training and do the backward pass.
+
+from tinygrad.nn.optim import SGD
+optim = SGD([l1, l2])
+
+X, Y = X_train[(samples:=Tensor.randint(128, high=X_train.shape[0]))], Y_train[samples]
+optim.zero_grad()
+model(X).sparse_categorical_crossentropy(Y).backward()
+optim._step()   # this will step the optimizer without running realize
+
+# *****
+# 3. Create a schedule.
+
+# The weight Tensors have been assigned to, but not yet realized. Everything is still lazy at this point
+# l1.lazydata and l2.lazydata define a computation graph
+
+from tinygrad.engine.schedule import ScheduleItem
+schedule: List[ScheduleItem] = Tensor.schedule(l1, l2)
+
+print(f"The schedule contains {len(schedule)} items.")
+for si in schedule: print(str(si)[:80])
+
+# *****
+# 4. Lower a schedule.
+
+from tinygrad.engine.realize import lower_schedule_item, ExecItem
+lowered: List[ExecItem] = [ExecItem(lower_schedule_item(si).prg, list(si.bufs)) for si in tqdm(schedule)]
+
+# *****
+# 5. Run the schedule
+
+for ei in tqdm(lowered): ei.run()
+
+# *****
+# 6. Print the weight change
+
+print("first weight change\n", l1.numpy()-l1n)
+print("second weight change\n", l2.numpy()-l2n)

+ 56 - 0
tinychat/docs/developer.md

@@ -0,0 +1,56 @@
+The tinygrad framework has four pieces
+
+* a PyTorch-like <b>frontend</b>.
+* a <b>scheduler</b> which breaks the compute into kernels.
+* a <b>lowering</b> engine which converts ASTs into code that can run on the accelerator.
+* an <b>execution</b> engine which can run that code.
+
+There is a good [set of tutorials](https://mesozoic-egg.github.io/tinygrad-notes/) by Di Zhu that goes over tinygrad internals.
+
+## Frontend
+
+Everything in [Tensor](tensor/index.md) is syntactic sugar around [function.py](function.md), where the forward and backward passes are implemented for the different functions. There are about 25 of them, implemented using about 20 basic ops. Those basic ops go on to construct a graph of:
+
+::: tinygrad.lazy.LazyBuffer
+    options:
+        show_source: false
+
+The `LazyBuffer` graph specifies the compute in terms of low-level tinygrad ops. Not all LazyBuffers will actually be realized. There are two types of LazyBuffer: base and view. A base holds compute that materializes into a contiguous buffer, while a view is a reinterpretation of a base (specified by a ShapeTracker). Inputs to a base can be either bases or views; inputs to a view can only be a single base.
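+
+A minimal sketch of the laziness this graph enables, using only the public `Tensor` API (the split into base and view LazyBuffers happens underneath and is an internal detail):
+
+```python
+from tinygrad import Tensor
+
+a = Tensor.rand(4, 4)     # builds a LazyBuffer graph, no data is computed yet
+b = a.reshape(16)         # a view of `a` (new ShapeTracker), no copy
+c = (b * 2).sum()         # still lazy, nothing has run on the device
+print(c.item())           # .item()/.numpy() realizes the graph and runs the kernels
+```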
+
+## Scheduling
+
+The [scheduler](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/engine/schedule.py) converts the graph of LazyBuffers into a list of `ScheduleItem`. One `ScheduleItem` is one kernel on the GPU, and the scheduler is responsible for breaking the large compute graph into subgraphs that can fit in a kernel. `ast` specifies what compute to run, and `bufs` specifies what buffers to run it on.
+
+::: tinygrad.engine.schedule.ScheduleItem
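+
+A hedged sketch of producing a schedule from the frontend (the same `Tensor.schedule` call is used in `docs/abstractions3.py`; `ast` and `bufs` are the fields described above):
+
+```python
+from tinygrad import Tensor
+
+out = (Tensor.rand(16, 16) @ Tensor.rand(16, 16)).relu()
+sched = out.schedule()                           # list of ScheduleItem, roughly one kernel each
+for si in sched: print(si.ast.op, len(si.bufs))  # what to compute, and on which buffers
+```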
+
+## Lowering
+
+The code in [realize](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/engine/realize.py) lowers `ScheduleItem` to `ExecItem` with
+
+::: tinygrad.engine.realize.lower_schedule
+
+There's a ton of complexity hidden behind this, see the `codegen/` directory.
+
+First we lower the AST to UOps, which is a linear list of the compute to be run. This is where the BEAM search happens.
+
+Then we render the UOps into code with a `Renderer`, then we compile the code to binary with a `Compiler`.
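+
+A hedged sketch of lowering and running a schedule by hand, mirroring `docs/abstractions3.py` (assumes `sched` is the list of `ScheduleItem` from the scheduling step above):
+
+```python
+from tinygrad.engine.realize import lower_schedule_item, ExecItem
+
+lowered = [ExecItem(lower_schedule_item(si).prg, list(si.bufs)) for si in sched]
+for ei in lowered: ei.run()   # run each lowered item
+```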
+
+## Execution
+
+Lowering creates an `ExecItem`, which has a run method:
+
+::: tinygrad.engine.realize.ExecItem
+    options:
+        members: true
+
+Lists of `ExecItem` can be condensed into a single ExecItem with the Graph API (rename to Queue?)
+
+## Runtime
+
+Runtimes are responsible for device-specific interactions. They handle tasks such as initializing devices, allocating memory, loading/launching programs, and more. You can find more information about the runtimes API on the [runtime overview page](runtime/overview.md).
+
+All runtime implementations can be found in the [runtime directory](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/runtime).
+
+### HCQ Compatible Runtimes
+
+The HCQ API is a lower-level API for defining runtimes. Interaction with HCQ-compatible devices occurs at a lower level, with commands issued directly to hardware queues. Some examples of such backends are [NV](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/runtime/ops_nv.py) and [AMD](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/runtime/ops_amd.py), which are userspace drivers for NVIDIA and AMD devices respectively. You can find more information about the API on the [HCQ overview page](runtime/hcq.md).

+ 9 - 0
tinychat/docs/dtypes.md

@@ -0,0 +1,9 @@
+::: tinygrad.dtype.DType
+
+::: tinygrad.dtype.dtypes
+    options:
+        members: true
+        members_order: source
+        show_labels: false
+
+::: tinygrad.dtype.ConstType

+ 52 - 0
tinychat/docs/env_vars.md

@@ -0,0 +1,52 @@
+# List of environment variables that control tinygrad behavior.
+
+This is a list of environment variables that control the runtime behavior of tinygrad and its examples.
+Most of these are self-explanatory, and are usually used to set an option at runtime.
+
+Example: `GPU=1 DEBUG=4 python3 -m pytest`
+
+However, you can also decorate a function to set a value only inside that function.
+
+```python
+# in tensor.py (probably only useful if you are a tinygrad developer)
+@Context(DEBUG=4)
+def numpy(self) -> ...
+```
+
+Or use contextmanager to temporarily set a value inside some scope:
+
+```python
+with Context(DEBUG=0):
+  a = Tensor.ones(10, 10)
+  a *= 2
+```
+
+## Global Variables
+The columns of this list are: Variable, Possible Value(s), and Description.
+
+- A `#` means that the variable can take any integer value.
+
+These control the behavior of core tinygrad even when used as a library.
+
+Variable | Possible Value(s) | Description
+---|---|---
+DEBUG               | [1-6]      | enable debugging output, with 4 you get operations, timings, speed, generated code and more
+GPU                 | [1]        | enable the GPU backend
+CUDA                | [1]        | enable CUDA backend
+AMD                 | [1]        | enable AMD backend
+NV                  | [1]        | enable NV backend
+METAL               | [1]        | enable Metal backend (for Mac M1 and after)
+METAL_XCODE         | [1]        | enable Metal using macOS Xcode SDK
+CLANG               | [1]        | enable Clang backend
+LLVM                | [1]        | enable LLVM backend
+BEAM                | [#]        | number of beams in kernel beam search
+GRAPH               | [1]        | create a graph of all operations (requires graphviz)
+GRAPHUOPS           | [1]        | create a graph of uops (requires graphviz and saves at /tmp/uops.{svg,dot})
+GRAPHPATH           | [/path/to] | where to put the generated graph
+DEFAULT_FLOAT       | [HALF, ...]| specify the default float dtype (FLOAT32, HALF, BFLOAT16, FLOAT64, ...), defaults to FLOAT32
+IMAGE               | [1-2]      | enable 2d specific optimizations
+FLOAT16             | [1]        | use float16 for images instead of float32
+PTX                 | [1]        | enable the specialized [PTX](https://docs.nvidia.com/cuda/parallel-thread-execution/) assembler for Nvidia GPUs. If not set, defaults to generic CUDA codegen backend.
+PROFILE             | [1]        | enable output of [perfetto](https://ui.perfetto.dev/) compatible profile. This feature is supported in NV and AMD backends.
+VISIBLE_DEVICES     | [list[int]]| restricts the NV/AMD devices that are available. The format is a comma-separated list of identifiers (indexing starts with 0).
+JIT                 | [0-2]      | 0=disabled, 1=[jit enabled](quickstart.md#jit) (default), 2=jit enabled, but graphs are disabled
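+
+From Python, the same variables can be read with `getenv` and the `ContextVar`s in `tinygrad.helpers`. A minimal sketch:
+
+```python
+from tinygrad import Tensor
+from tinygrad.helpers import Context, getenv, DEBUG
+
+print(getenv("GPU", 0))                      # plain integer read with a default of 0
+if DEBUG >= 2: print("timing output is on")  # DEBUG is a ContextVar, comparable like an int
+with Context(DEBUG=4):                       # temporarily crank up debugging for this block
+  Tensor.ones(4).contiguous().realize()
+```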

+ 25 - 0
tinychat/docs/favicon.svg

@@ -0,0 +1,25 @@
+<svg xmlns="http://www.w3.org/2000/svg" viewBox="-10 -10 150 70" shape-rendering="crispEdges">
+  <g id="logo">
+    <!-- t -->
+    <polygon points="10,40 10,20 0,20 0,10 10,10 10,0 20,0 20,10 30,10 30,20 20,20 20,30 30,30 30,40" />
+    <!-- i -->
+    <polygon points="40,40 40,20 50,20 50,40" />
+    <polygon points="40,10 40,0 50,0 50,10" />
+    <!-- n -->
+    <polygon points="60,40 60,10 80,10 80,40 90,40 90,20 70,20 70,40" />
+    <!-- y -->
+    <polygon points="100,50 100,40 130,40 130,10 120,10 120,20 110,20 110,10 100,10 100,30 120,30 120,50" />
+  </g>
+  <style>
+  @media (prefers-color-scheme: dark) {
+    #logo {
+      fill: #fff;
+    }
+  }
+  @media (prefers-color-scheme: light) {
+    #logo {
+      fill: #000;
+    }
+  }
+  </style>
+</svg>

+ 33 - 0
tinychat/docs/function.md

@@ -0,0 +1,33 @@
+::: tinygrad.function
+    options:
+        members: [
+            "Contiguous",
+            "ContiguousBackward",
+            "Cast",
+            "Neg",
+            "Reciprocal",
+            "Sin",
+            "Relu",
+            "Log",
+            "Exp",
+            "Sqrt",
+            "Sigmoid",
+            "Sign",
+            "Less",
+            "Eq",
+            "Xor",
+            "Add",
+            "Sub",
+            "Mul",
+            "Div",
+            "Where",
+            "Sum",
+            "Max",
+            "Expand",
+            "Reshape",
+            "Permute",
+            "Pad",
+            "Shrink",
+            "Flip",
+        ]
+        show_source: false

+ 47 - 0
tinychat/docs/index.md

@@ -0,0 +1,47 @@
+# tinygrad documentation
+
+Welcome to the docs for tinygrad. This page is for users of the tinygrad library. tinygrad is not 1.0 yet, but it will be soon. The API has been pretty stable for a while.
+
+While you can `pip install tinygrad`, we encourage you to install from source:
+
+```bash
+git clone https://github.com/tinygrad/tinygrad.git
+cd tinygrad
+python3 -m pip install -e .
+```
+
+After you have installed tinygrad, try the [MNIST tutorial](mnist.md)
+
+We also have [developer docs](developer.md), and Di Zhu has created a [bunch of tutorials](https://mesozoic-egg.github.io/tinygrad-notes/) to help understand how tinygrad works.
+
+## tinygrad Usage
+
+The main class you will interact with is [Tensor](tensor/index.md). It functions very similarly to PyTorch, but has a bit more of a functional style. tinygrad supports [many datatypes](dtypes.md).  All operations in tinygrad are lazy, meaning they won't do anything until you realize.
+
+* tinygrad has a built in [neural network library](nn.md) with some classes, optimizers, and load/save state management.
+* tinygrad has a JIT to make things fast. Decorate your pure function with `TinyJit`
+* tinygrad has amazing support for multiple GPUs, allowing you to shard your Tensors with `Tensor.shard`
+
+To understand what training looks like in tinygrad, you should read `beautiful_mnist.py`
+
+We have a [quickstart guide](quickstart.md) and a [showcase](showcase.md)
+
+## Differences from PyTorch
+
+If you are migrating from PyTorch, welcome. Most of the API is the same. We hope you will find tinygrad both familiar and somehow more "correct feeling"
+
+### tinygrad doesn't have nn.Module
+
+There's nothing special about a "Module" class in tinygrad, it's just a normal class. [`nn.state.get_parameters`](nn.md/#tinygrad.nn.state.get_parameters) can be used to recursively search normal classes for valid tensors. Instead of the `forward` method in PyTorch, tinygrad just uses `__call__`
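+
+A minimal sketch of what that looks like (the class and names below are just an illustration):
+
+```python
+from tinygrad import Tensor, nn
+
+class TinyMLP:
+  def __init__(self):
+    self.l1, self.l2 = nn.Linear(4, 8), nn.Linear(8, 1)
+  def __call__(self, x: Tensor) -> Tensor:
+    return self.l2(self.l1(x).relu())
+
+print(len(nn.state.get_parameters(TinyMLP())))  # 4 tensors: two weights and two biases
+```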
+
+### tinygrad is functional
+
+In tinygrad, you can do [`x.conv2d(w, b)`](tensor/ops.md/#tinygrad.Tensor.conv2d) or [`x.sparse_categorical_crossentropy(y)`](tensor/ops.md/#tinygrad.Tensor.sparse_categorical_crossentropy). We also have a [`Conv2D`](nn.md/#tinygrad.nn.Conv2d) class like PyTorch if you want a place to keep the state, but all stateless operations don't have classes.
+
+### tinygrad is lazy
+
+When you do `a+b` in tinygrad, nothing happens. It's not until you [`realize`](tensor/index.md/#tinygrad.Tensor.realize) the Tensor that the computation actually runs.
+
+### tinygrad requires @TinyJit to be fast
+
+PyTorch spends a lot of development effort to make dispatch very fast. tinygrad doesn't. We have a simple decorator that will replay the kernels used in the decorated function.
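+
+Here is a minimal sketch of the lazy and JIT behavior described above (the `double` function is just an illustration):
+
+```python
+from tinygrad import Tensor, TinyJit
+
+@TinyJit
+def double(x: Tensor) -> Tensor:
+  return (x * 2).realize()       # returns from a jitted function should be realized
+
+a = Tensor.ones(3) + 1           # still lazy, no kernel has run
+print(a.numpy())                 # realization happens here
+for _ in range(3): print(double(Tensor.rand(3)).numpy())  # captured on the second call, replayed afterwards
+```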

+ 11 - 0
tinychat/docs/logo_tiny_dark.svg

@@ -0,0 +1,11 @@
+<svg xmlns="http://www.w3.org/2000/svg" viewBox="-10 -10 150 70" shape-rendering="crispEdges" fill="#fff">
+  <!-- t -->
+  <polygon points="10,40 10,20 0,20 0,10 10,10 10,0 20,0 20,10 30,10 30,20 20,20 20,30 30,30 30,40" />
+  <!-- i -->
+  <polygon points="40,40 40,20 50,20 50,40" />
+  <polygon points="40,10 40,0 50,0 50,10" />
+  <!-- n -->
+  <polygon points="60,40 60,10 80,10 80,40 90,40 90,20 70,20 70,40" />
+  <!-- y -->
+  <polygon points="100,50 100,40 130,40 130,10 120,10 120,20 110,20 110,10 100,10 100,30 120,30 120,50" />
+</svg>

+ 11 - 0
tinychat/docs/logo_tiny_light.svg

@@ -0,0 +1,11 @@
+<svg xmlns="http://www.w3.org/2000/svg" viewBox="-10 -10 150 70" shape-rendering="crispEdges">
+  <!-- t -->
+  <polygon points="10,40 10,20 0,20 0,10 10,10 10,0 20,0 20,10 30,10 30,20 20,20 20,30 30,30 30,40" />
+  <!-- i -->
+  <polygon points="40,40 40,20 50,20 50,40" />
+  <polygon points="40,10 40,0 50,0 50,10" />
+  <!-- n -->
+  <polygon points="60,40 60,10 80,10 80,40 90,40 90,20 70,20 70,40" />
+  <!-- y -->
+  <polygon points="100,50 100,40 130,40 130,10 120,10 120,20 110,20 110,10 100,10 100,30 120,30 120,50" />
+</svg>

+ 177 - 0
tinychat/docs/mnist.md

@@ -0,0 +1,177 @@
+# MNIST Tutorial
+
+After you have installed tinygrad, this is a great first tutorial.
+
+Start up a notebook locally, or use [colab](https://colab.research.google.com/). tinygrad is very lightweight, so it's easy to install anywhere and doesn't need a special colab image, but for speed we recommend a T4 GPU image.
+
+### One-liner to install tinygrad in colab
+
+```python
+!pip install git+https://github.com/tinygrad/tinygrad.git
+```
+
+### What's the default device?
+
+```python
+from tinygrad import Device
+print(Device.DEFAULT)
+```
+
+You will see `CUDA` here on a GPU instance, or `CLANG` here on a CPU instance.
+
+## A simple model
+
+We'll use the model from [the Keras tutorial](https://keras.io/examples/vision/mnist_convnet/).
+
+```python
+from tinygrad import Tensor, nn
+
+class Model:
+  def __init__(self):
+    self.l1 = nn.Conv2d(1, 32, kernel_size=(3,3))
+    self.l2 = nn.Conv2d(32, 64, kernel_size=(3,3))
+    self.l3 = nn.Linear(1600, 10)
+
+  def __call__(self, x:Tensor) -> Tensor:
+    x = self.l1(x).relu().max_pool2d((2,2))
+    x = self.l2(x).relu().max_pool2d((2,2))
+    return self.l3(x.flatten(1).dropout(0.5))
+```
+
+Two key differences from PyTorch:
+
+* Only the stateful layers are declared in `__init__`
+* There's no `nn.Module` class or `forward` function, just a normal class and `__call__`
+
+### Getting the dataset
+
+```python
+from tinygrad.nn.datasets import mnist
+X_train, Y_train, X_test, Y_test = mnist()
+print(X_train.shape, X_train.dtype, Y_train.shape, Y_train.dtype)
+# (60000, 1, 28, 28) dtypes.uchar (60000,) dtypes.uchar
+```
+
+tinygrad includes MNIST; the loader only adds four lines. Feel free to read the [function](https://github.com/tinygrad/tinygrad/blob/master/tinygrad/nn/datasets.py).
+
+## Using the model
+
+MNIST is small enough that the `mnist()` function copies the dataset to the default device.
+
+So creating the model and evaluating it is a matter of:
+
+```python
+model = Model()
+acc = (model(X_test).argmax(axis=1) == Y_test).mean()
+# NOTE: tinygrad is lazy, and hasn't actually run anything by this point
+print(acc.item())  # ~10% accuracy, as expected from a random model
+```
+
+### Training the model
+
+We'll use the Adam optimizer. `nn.state.get_parameters` will walk the model class and pull out the parameters for the optimizer. Also, in tinygrad, it's typical to write a function to do the training step so it can be jitted.
+
+```python
+optim = nn.optim.Adam(nn.state.get_parameters(model))
+batch_size = 128
+def step():
+  Tensor.training = True  # makes dropout work
+  samples = Tensor.randint(batch_size, high=X_train.shape[0])
+  X, Y = X_train[samples], Y_train[samples]
+  optim.zero_grad()
+  loss = model(X).sparse_categorical_crossentropy(Y).backward()
+  optim.step()
+  return loss
+```
+
+You can time a step with:
+
+```python
+import timeit
+timeit.repeat(step, repeat=5, number=1)
+#[0.08268719699981375,
+# 0.07478952900009972,
+# 0.07714716600003158,
+# 0.07785399599970333,
+# 0.07605237000007037]
+```
+
+So around 75 ms per step on a T4 in Colab.
+
+### Why so slow?
+
+Unlike PyTorch, tinygrad doesn't spend effort making unjitted dispatch fast. While 75 ms for one step is plenty fast for debugging, it's not great for training. Here, we introduce the first quintessentially tinygrad concept, the `TinyJit`.
+
+```python
+from tinygrad import TinyJit
+jit_step = TinyJit(step)
+```
+
+NOTE: It can also be used as a decorator `@TinyJit`
+
+Now when we time it:
+
+```python
+import timeit
+timeit.repeat(jit_step, repeat=5, number=1)
+# [0.2596786549997887,
+#  0.08989566299987928,
+#  0.0012115650001760514,
+#  0.001010227999813651,
+#  0.0012164899999334011]
+```
+
+1.0 ms is 75x faster! Note that we aren't syncing the GPU, so GPU time may be slower.
+
+The slowness of the first two runs is the JIT capturing the kernels. This JIT will not run any Python in the function; it just replays the tinygrad kernels that were run, so be aware that non-tinygrad Python operations won't work. Randomness functions work as expected.
+
+Unlike other JITs, we JIT everything, including the optimizer. Think of it as a dumb replay on different data.
+
+## Putting it together
+
+Since we are just randomly sampling from the dataset, there's no real concept of an epoch. We have a batch size of 128, so the Keras example is taking about 7000 steps.
+
+```python
+for step in range(7000):
+  loss = jit_step()
+  if step%100 == 0:
+    Tensor.training = False
+    acc = (model(X_test).argmax(axis=1) == Y_test).mean().item()
+    print(f"step {step:4d}, loss {loss.item():.2f}, acc {acc*100.:.2f}%")
+```
+
+It doesn't take long to reach 98%, and it usually reaches 99%.
+
+```
+step    0, loss 4.03, acc 71.43%
+step  100, loss 0.34, acc 93.86%
+step  200, loss 0.23, acc 95.97%
+step  300, loss 0.18, acc 96.32%
+step  400, loss 0.18, acc 96.76%
+step  500, loss 0.13, acc 97.46%
+step  600, loss 0.14, acc 97.45%
+step  700, loss 0.10, acc 97.27%
+step  800, loss 0.23, acc 97.49%
+step  900, loss 0.13, acc 97.51%
+step 1000, loss 0.13, acc 97.88%
+step 1100, loss 0.11, acc 97.72%
+step 1200, loss 0.14, acc 97.65%
+step 1300, loss 0.12, acc 98.04%
+step 1400, loss 0.25, acc 98.17%
+step 1500, loss 0.11, acc 97.86%
+step 1600, loss 0.21, acc 98.21%
+step 1700, loss 0.14, acc 98.34%
+...
+```
+
+## From here?
+
+tinygrad is yours to play with now. It's pure Python and short, so unlike PyTorch, fixing library bugs is well within your abilities.
+
+- It's two lines to add multiGPU support to this example (can you find them?). You have to `.shard` the model to all GPUs, and `.shard` the dataset by batch. A sketch follows this list.
+- `with Context(DEBUG=2)` shows the running kernels, `DEBUG=4` shows the code. All `Context` variables can also be environment variables.
+- `with Context(BEAM=2)` will do a BEAM search on the kernels, searching many possible implementations for what runs the fastest on your hardware. After this search, tinygrad is usually speed competitive with PyTorch, and the results are cached so you won't have to search next time.
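+
+A minimal sketch of those two multiGPU lines, assuming a machine with 2 GPUs and reusing `model`, `X_train`, `Y_train`, and `samples` from the tutorial above (see `examples/beautiful_mnist_multigpu.py` for the full version):
+
+```python
+from tinygrad import Device, nn
+
+GPUS = [f"{Device.DEFAULT}:{i}" for i in range(2)]              # e.g. ["CUDA:0", "CUDA:1"]
+for p in nn.state.get_state_dict(model).values(): p.to_(GPUS)   # put a copy of the model on every GPU
+# ...and inside step(), shard the sampled batch across the GPUs on the batch axis:
+X, Y = X_train[samples].shard_(GPUS, axis=0), Y_train[samples].shard_(GPUS, axis=0)
+```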
+
+[Join our Discord](https://discord.gg/ZjZadyC7PK) for help, and if you want to be a tinygrad developer. Please read the Discord rules when you get there.
+
+[Follow us on Twitter](https://twitter.com/__tinygrad__) to keep up with the project.

+ 31 - 0
tinychat/docs/nn.md

@@ -0,0 +1,31 @@
+## Neural Network classes
+
+::: tinygrad.nn.BatchNorm
+::: tinygrad.nn.Conv1d
+::: tinygrad.nn.Conv2d
+::: tinygrad.nn.ConvTranspose1d
+::: tinygrad.nn.ConvTranspose2d
+::: tinygrad.nn.Linear
+::: tinygrad.nn.GroupNorm
+::: tinygrad.nn.InstanceNorm
+::: tinygrad.nn.LayerNorm
+::: tinygrad.nn.LayerNorm2d
+::: tinygrad.nn.RMSNorm
+::: tinygrad.nn.Embedding
+
+## Optimizers
+
+::: tinygrad.nn.optim.SGD
+::: tinygrad.nn.optim.LARS
+::: tinygrad.nn.optim.AdamW
+::: tinygrad.nn.optim.Adam
+::: tinygrad.nn.optim.LAMB
+
+## Load/Save
+
+::: tinygrad.nn.state.safe_load
+::: tinygrad.nn.state.safe_save
+::: tinygrad.nn.state.get_state_dict
+::: tinygrad.nn.state.get_parameters
+::: tinygrad.nn.state.load_state_dict
+::: tinygrad.nn.state.torch_load

+ 308 - 0
tinychat/docs/quickstart.md

@@ -0,0 +1,308 @@
+# Quick Start Guide
+
+This guide assumes no prior knowledge of PyTorch or any other deep learning framework, but does assume some basic knowledge of neural networks.
+It is intended to be a very quick overview of the high level API that tinygrad provides.
+
+This guide is also structured as a tutorial; by the end of it you will have a working model that can classify handwritten digits.
+
+We need some imports to get started:
+
+```python
+import numpy as np
+from tinygrad.helpers import Timing
+```
+
+## Tensors
+
+Tensors are the base data structure in tinygrad. They can be thought of as a multidimensional array of a specific data type.
+All high level operations in tinygrad operate on these tensors.
+
+The tensor class can be imported like so:
+
+```python
+from tinygrad import Tensor
+```
+
+Tensors can be created from an existing data structure like a python list or numpy ndarray:
+
+```python
+t1 = Tensor([1, 2, 3, 4, 5])
+na = np.array([1, 2, 3, 4, 5])
+t2 = Tensor(na)
+```
+
+Tensors can also be created using one of the many factory methods:
+
+```python
+full = Tensor.full(shape=(2, 3), fill_value=5) # create a tensor of shape (2, 3) filled with 5
+zeros = Tensor.zeros(2, 3) # create a tensor of shape (2, 3) filled with 0
+ones = Tensor.ones(2, 3) # create a tensor of shape (2, 3) filled with 1
+
+full_like = Tensor.full_like(full, fill_value=2) # create a tensor of the same shape as `full` filled with 2
+zeros_like = Tensor.zeros_like(full) # create a tensor of the same shape as `full` filled with 0
+ones_like = Tensor.ones_like(full) # create a tensor of the same shape as `full` filled with 1
+
+eye = Tensor.eye(3) # create a 3x3 identity matrix
+arange = Tensor.arange(start=0, stop=10, step=1) # create a tensor of shape (10,) filled with values from 0 to 9
+
+rand = Tensor.rand(2, 3) # create a tensor of shape (2, 3) filled with random values from a uniform distribution
+randn = Tensor.randn(2, 3) # create a tensor of shape (2, 3) filled with random values from a standard normal distribution
+uniform = Tensor.uniform(2, 3, low=0, high=10) # create a tensor of shape (2, 3) filled with random values from a uniform distribution between 0 and 10
+```
+
+There are even more of these factory methods; you can find them on the [Tensor Creation](tensor/creation.md) page.
+
+All the tensor creation methods can take a `dtype` argument to specify the data type of the tensor; the supported dtypes are listed in [dtypes](dtypes.md).
+
+```python
+from tinygrad import dtypes
+
+t3 = Tensor([1, 2, 3, 4, 5], dtype=dtypes.int32)
+```
+
+Tensors allow you to perform operations on them like so:
+
+```python
+t4 = Tensor([1, 2, 3, 4, 5])
+t5 = (t4 + 1) * 2
+t6 = (t5 * t4).relu().log_softmax()
+```
+
+All of these operations are lazy and are only executed when you realize the tensor using `.realize()` or `.numpy()`.
+
+```python
+print(t6.numpy())
+# [-56. -48. -36. -20.   0.]
+```
+
+There are a lot more operations that can be performed on tensors; you can find them on the [Tensor Ops](tensor/ops.md) page.
+Additionally reading through [abstractions2.py](https://github.com/tinygrad/tinygrad/blob/master/docs/abstractions2.py) will help you understand how operations on these tensors make their way down to your hardware.
+
+## Models
+
+Neural networks in tinygrad are really just represented by the operations performed on tensors.
+These operations are commonly grouped into the `__call__` method of a class which allows modularization and reuse of these groups of operations.
+These classes do not need to inherit from any base class, in fact if they don't need any trainable parameters they don't even need to be a class!
+
+An example of this would be the `nn.Linear` class which represents a linear layer in a neural network.
+
+```python
+class Linear:
+  def __init__(self, in_features, out_features, bias=True, initialization: str='kaiming_uniform'):
+    self.weight = getattr(Tensor, initialization)(out_features, in_features)
+    self.bias = Tensor.zeros(out_features) if bias else None
+
+  def __call__(self, x):
+    return x.linear(self.weight.transpose(), self.bias)
+```
+
+There are more neural network modules already implemented in [nn](nn.md), and you can also implement your own.
+
+We will be implementing a simple neural network that can classify handwritten digits from the MNIST dataset.
+Our classifier will be a simple 2 layer neural network with a Leaky ReLU activation function.
+It will use a hidden layer size of 128 and an output layer size of 10 (one for each digit) with no bias on either Linear layer.
+
+```python
+class TinyNet:
+  def __init__(self):
+    self.l1 = Linear(784, 128, bias=False)
+    self.l2 = Linear(128, 10, bias=False)
+
+  def __call__(self, x):
+    x = self.l1(x)
+    x = x.leakyrelu()
+    x = self.l2(x)
+    return x
+
+net = TinyNet()
+```
+
+We can see that the forward pass of our neural network is just the sequence of operations performed on the input tensor `x`.
+We can also see that functional operations like `leakyrelu` are not defined as classes, but are just methods we can call.
+Finally, we just initialize an instance of our neural network, and we are ready to start training it.
+
+## Training
+
+Now that we have our neural network defined we can start training it.
+Training neural networks in tinygrad is super simple.
+All we need to do is define our neural network, define our loss function, and then call `.backward()` on the loss function to compute the gradients.
+They can then be used to update the parameters of our neural network using one of the many [Optimizers](nn.md#optimizers).
+
+For our loss function we will be using sparse categorical cross entropy loss. The implementation below is taken from [tensor.py](https://github.com/tinygrad/tinygrad/blob/master/tinygrad/tensor.py), it's copied below to highlight an important detail of tinygrad.
+
+```python
+def sparse_categorical_crossentropy(self, Y, ignore_index=-1) -> Tensor:
+    loss_mask = Y != ignore_index
+    y_counter = Tensor.arange(self.shape[-1], dtype=dtypes.int32, requires_grad=False, device=self.device).unsqueeze(0).expand(Y.numel(), self.shape[-1])
+    y = ((y_counter == Y.flatten().reshape(-1, 1)).where(-1.0, 0) * loss_mask.reshape(-1, 1)).reshape(*Y.shape, self.shape[-1])
+    return self.log_softmax().mul(y).sum() / loss_mask.sum()
+```
+
+As we can see in this implementation of cross entropy loss, there are certain operations that tinygrad does not support natively.
+Load/store ops are not supported in tinygrad natively because they add complexity when trying to port to different backends, 90% of the models out there don't use/need them, and they can be implemented like it's done above with an `arange` mask.
+
+For our optimizer we will be using the traditional stochastic gradient descent optimizer with a learning rate of 3e-4.
+
+```python
+from tinygrad.nn.optim import SGD
+
+opt = SGD([net.l1.weight, net.l2.weight], lr=3e-4)
+```
+
+We can see that we are passing in the parameters of our neural network to the optimizer.
+This is due to the fact that the optimizer needs to know which parameters to update.
+There is a simpler way to do this, by using `get_parameters(net)` from `tinygrad.nn.state`, which will return a list of all the parameters in the neural network.
+The parameters are just listed out explicitly here for clarity.
+
+Now that we have our network, loss function, and optimizer defined all we are missing is the data to train on!
+There are a couple of dataset loaders in tinygrad located in [/extra/datasets](https://github.com/tinygrad/tinygrad/blob/master/extra/datasets).
+We will be using the MNIST dataset loader.
+
+```python
+from extra.datasets import fetch_mnist
+```
+
+Now we have everything we need to start training our neural network.
+We will be training for 1000 steps with a batch size of 64.
+
+We use `with Tensor.train()` to set the internal flag `Tensor.training` to `True` during training.
+Upon exit, the flag is restored to its previous value by the context manager.
+
+```python
+X_train, Y_train, X_test, Y_test = fetch_mnist()
+
+with Tensor.train():
+  for step in range(1000):
+    # random sample a batch
+    samp = np.random.randint(0, X_train.shape[0], size=(64))
+    batch = Tensor(X_train[samp], requires_grad=False)
+    # get the corresponding labels
+    labels = Tensor(Y_train[samp])
+
+    # forward pass
+    out = net(batch)
+
+    # compute loss
+    loss = sparse_categorical_crossentropy(out, labels)
+
+    # zero gradients
+    opt.zero_grad()
+
+    # backward pass
+    loss.backward()
+
+    # update parameters
+    opt.step()
+
+    # calculate accuracy
+    pred = out.argmax(axis=-1)
+    acc = (pred == labels).mean()
+
+    if step % 100 == 0:
+      print(f"Step {step+1} | Loss: {loss.numpy()} | Accuracy: {acc.numpy()}")
+```
+
+## Evaluation
+
+Now that we have trained our neural network we can evaluate it on the test set.
+We will be using the same batch size of 64 and will be evaluating for 1000 of those batches.
+
+```python
+with Timing("Time: "):
+  avg_acc = 0
+  for step in range(1000):
+    # random sample a batch
+    samp = np.random.randint(0, X_test.shape[0], size=(64))
+    batch = Tensor(X_test[samp], requires_grad=False)
+    # get the corresponding labels
+    labels = Y_test[samp]
+
+    # forward pass
+    out = net(batch)
+
+    # calculate accuracy
+    pred = out.argmax(axis=-1).numpy()
+    avg_acc += (pred == labels).mean()
+  print(f"Test Accuracy: {avg_acc / 1000}")
+```
+
+## And that's it
+
+Highly recommend you check out the [examples/](https://github.com/tinygrad/tinygrad/blob/master/examples) folder for more examples of using tinygrad.
+Reading the source code of tinygrad is also a great way to learn how it works.
+Specifically, the tests in [test/](https://github.com/tinygrad/tinygrad/blob/master/test) are a great place to see how to use the different operations and what their semantics are.
+There are also a bunch of models implemented in [models/](https://github.com/tinygrad/tinygrad/blob/master/extra/models) that you can use as a reference.
+
+Additionally, feel free to ask questions in the `#learn-tinygrad` channel on the [discord](https://discord.gg/beYbxwxVdx). Don't ask to ask, just ask!
+
+## Extras
+
+### JIT
+
+Additionally, it is possible to speed up the computation of certain neural networks by using the JIT.
+Currently, this does not support models with varying input sizes or non-tinygrad operations.
+
+To use the JIT we just need to add a function decorator to the forward pass of our neural network and ensure that the input and output are realized tensors.
+Or in this case we will create a wrapper function and decorate the wrapper function to speed up the evaluation of our neural network.
+
+```python
+from tinygrad import TinyJit
+
+@TinyJit
+def jit(x):
+  return net(x).realize()
+
+with Timing("Time: "):
+  avg_acc = 0
+  for step in range(1000):
+    # random sample a batch
+    samp = np.random.randint(0, X_test.shape[0], size=(64))
+    batch = Tensor(X_test[samp], requires_grad=False)
+    # get the corresponding labels
+    labels = Y_test[samp]
+
+    # forward pass with jit
+    out = jit(batch)
+
+    # calculate accuracy
+    pred = out.argmax(axis=-1).numpy()
+    avg_acc += (pred == labels).mean()
+  print(f"Test Accuracy: {avg_acc / 1000}")
+```
+
+You will find that the evaluation time is much faster than before and that your accelerator utilization is much higher.
+
+### Saving and Loading Models
+
+The standard weight format for tinygrad is [safetensors](https://github.com/huggingface/safetensors). This means that you can load the weights of any model that also uses safetensors into tinygrad.
+There are functions in [state.py](https://github.com/tinygrad/tinygrad/blob/master/tinygrad/nn/state.py) to save and load models to and from this format.
+
+```python
+from tinygrad.nn.state import safe_save, safe_load, get_state_dict, load_state_dict
+
+# first we need the state dict of our model
+state_dict = get_state_dict(net)
+
+# then we can just save it to a file
+safe_save(state_dict, "model.safetensors")
+
+# and load it back in
+state_dict = safe_load("model.safetensors")
+load_state_dict(net, state_dict)
+```
+
+Many of the models in the [models/](https://github.com/tinygrad/tinygrad/tree/master/extra/models) folder have a `load_from_pretrained` method that will download and load the weights for you. These are usually PyTorch weights, meaning that you would need PyTorch installed to load them.
+
+### Environment Variables
+
+There exist a bunch of environment variables that control the runtime behavior of tinygrad.
+Some of the common ones are `DEBUG` and the different backend enablement variables.
+
+You can find a full list and their descriptions in [env_vars.md](env_vars.md).
+
+### Visualizing the Computation Graph
+
+It is possible to visualize the computation graph of a neural network using [graphviz](https://graphviz.org/).
+
+This is easily done by running a single pass (forward or backward!) of the neural network with the environment variable `GRAPH` set to `1`.
+The graph will be saved to `/tmp/net.svg` by default.
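+
+A minimal sketch, assuming the `TinyNet` instance `net` defined earlier in this guide and a graphviz install:
+
+```python
+# run as:  GRAPH=1 python3 your_script.py
+from tinygrad import Tensor
+
+out = net(Tensor.randn(1, 784))   # one forward pass through the network defined above
+out.numpy()                       # realizing runs the kernels; with GRAPH=1 the graph lands in /tmp/net.svg
+```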

+ 146 - 0
tinychat/docs/runtime/hcq.md

@@ -0,0 +1,146 @@
+# HCQ Compatible Runtime
+
+## Overview
+
+The main aspect of HCQ-compatible runtimes is how they interact with devices. In HCQ, all interactions with devices occur in a hardware-friendly manner using [command queues](#commandqueues). This approach allows commands to be issued directly to devices, bypassing the overhead of runtimes such as HIP or CUDA. Additionally, by using the HCQ API, these runtimes can benefit from various optimizations and features, including [HCQGraph](#hcqgraph) and built-in profiling capabilities.
+
+### Command Queues
+
+To interact with devices, there are two types of queues: `HWComputeQueue` and `HWCopyQueue`. Commands defined in the base `HWCommandQueue` class should be supported by both queues; these are timestamp and synchronization methods like [signal](#tinygrad.device.HWCommandQueue.signal) and [wait](#tinygrad.device.HWCommandQueue.wait).
+
+For example, the following Python code enqueues a wait, execute, and signal command on the HCQ-compatible device:
+```python
+HWComputeQueue().wait(signal_to_wait, value_to_wait) \
+                .exec(program, kernargs_ptr, global_dims, local_dims) \
+                .signal(signal_to_fire, value_to_fire) \
+                .submit(your_device)
+```
+
+Each runtime should implement the required functions that are defined in the `HWCommandQueue`, `HWComputeQueue`, and `HWCopyQueue` classes.
+
+::: tinygrad.device.HWCommandQueue
+    options:
+        members: [
+            "signal",
+            "wait",
+            "timestamp",
+            "update_signal",
+            "update_wait",
+            "submit",
+        ]
+        show_source: false
+
+::: tinygrad.device.HWComputeQueue
+    options:
+        members: [
+            "memory_barrier",
+            "exec",
+            "update_exec",
+        ]
+        show_source: false
+
+::: tinygrad.device.HWCopyQueue
+    options:
+        members: [
+            "copy",
+            "update_copy",
+        ]
+        show_source: false
+
+#### Implementing custom commands
+
+To implement custom commands in the queue, use the `@hcq_command` decorator for your command implementations.
+
+::: tinygrad.device.hcq_command
+    options:
+        members: [
+            "copy",
+            "update_copy",
+        ]
+        show_source: false
+
+### HCQ Compatible Device
+
+The `HCQCompatCompiled` class defines the API for HCQ-compatible devices. This class serves as an abstract base class that device-specific implementations should inherit from and implement.
+
+::: tinygrad.device.HCQCompatCompiled
+    options:
+        members: [
+            "_alloc_signal",
+            "_free_signal",
+            "_read_signal",
+            "_read_timestamp",
+            "_set_signal",
+            "_wait_signal",
+            "_gpu2cpu_time",
+        ]
+        show_source: false
+
+#### Signals
+
+Signals are device-dependent structures used for synchronization and timing in HCQ-compatible devices. They should be designed to record both a `value` and a `timestamp` within the same signal. The following Python code demonstrates the usage of signals:
+
+```python
+signal = your_device._alloc_signal()
+
+HWComputeQueue().timestamp(signal) \
+                .signal(signal, value_to_fire) \
+                .submit(your_device)
+
+your_device._wait_signal(signal, value_to_fire)
+timestamp = your_device._read_timestamp(signal)
+```
+
+##### Synchronization signals
+
+Each HCQ-compatible device must allocate two signals for global synchronization purposes. These signals are passed to the `HCQCompatCompiled` base class during initialization: an active timeline signal `self.timeline_signal` and a shadow timeline signal `self._shadow_timeline_signal`, which helps to handle signal value overflow issues. You can find more about synchronization in the [synchronization section](#synchronization).
+
+### HCQ Compatible Allocator
+
+The `HCQCompatAllocator` base class simplifies allocator logic by leveraging [command queues](#commandqueues) abstractions. This class efficiently handles copy and transfer operations, leaving only the alloc and free functions to be implemented by individual backends.
+
+::: tinygrad.device.HCQCompatAllocator
+    options:
+        members: [
+            "_alloc",
+            "_free",
+        ]
+        show_source: false
+
+#### HCQ Allocator Result Protocol
+
+Backends must adhere to the `HCQCompatAllocRes` protocol when returning allocation results.
+
+::: tinygrad.device.HCQCompatAllocRes
+    options:
+        members: true
+        show_source: false
+
+### HCQ Compatible Program
+
+The `HCQCompatProgram` is a helper base class for defining programs compatible with HCQ-compatible devices. Currently, the arguments consist of pointers to buffers, followed by `vals` fields. The convention expects a packed struct containing the passed pointers, followed by `vals` located at `kernargs_args_offset`.
+
+::: tinygrad.device.HCQCompatProgram
+    options:
+        members: true
+        show_source: false
+
+### Synchronization
+
+HCQ-compatible devices use a global timeline signal for synchronizing all operations. This mechanism ensures proper ordering and completion of tasks across the device. By convention, `self.timeline_value` points to the next value to signal. So, to wait for all previous operations on the device to complete, wait on the value `self.timeline_value - 1`. The following Python code demonstrates the typical usage of signals to synchronize execution with other operations on the device:
+
+```python
+HWComputeQueue().wait(your_device.timeline_signal, your_device.timeline_value - 1) \
+                .exec(...) \
+                .signal(your_device.timeline_signal, your_device.timeline_value) \
+                .submit(your_device)
+your_device.timeline_value += 1
+
+# Optionally wait for execution
+your_device._wait_signal(your_device.timeline_signal, your_device.timeline_value - 1)
+```
+
+## HCQGraph
+
+[HCQGraph](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/runtime/graph/hcq.py) is a core feature that implements `GraphRunner` for HCQ-compatible devices. `HCQGraph` builds a static `HWComputeQueue` and `HWCopyQueue` for all operations per device. To optimize enqueue time, only the necessary parts of the queues are updated for each run using the update APIs of the queues, avoiding a complete rebuild.
+Optionally, queues can implement a `bind` API, which allows further optimization by eliminating the need to copy the queues into the device ring.

+ 51 - 0
tinychat/docs/runtime/overview.md

@@ -0,0 +1,51 @@
+# Runtime Overview
+
+## Overview
+
+A typical runtime consists of the following parts:
+
+- [Compiled](#compiled)
+- [Allocator](#allocator)
+- [Program](#program)
+- [Compiler](#compiler)
+
+### Compiled
+
+The `Compiled` class is responsible for initializing and managing a device.
+
+::: tinygrad.device.Compiled
+    options:
+        members: [
+            "synchronize"
+        ]
+        show_source: false
+
+### Allocator
+
+The `Allocator` class is responsible for managing memory on the device. There is also a version called the `LRUAllocator`, which caches allocated buffers to optimize performance.
+
+::: tinygrad.device.Allocator
+    options:
+        members: true
+        show_source: false
+
+::: tinygrad.device.LRUAllocator
+    options:
+        members: true
+        show_source: false
+
+### Program
+
+The `Program` class is created for each loaded program. It is responsible for compiling and executing the program on the device. As an example, here is the `ClangProgram` implementation, which loads a program and runs it.
+
+::: tinygrad.runtime.ops_clang.ClangProgram
+    options:
+        members: true
+
+### Compiler
+
+The `Compiler` class compiles the output of the `Renderer` into a device-specific format.
+
+::: tinygrad.device.Compiler
+    options:
+        members: true
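+
+As a minimal sketch of how these pieces fit together (assuming the default backend on your machine), `Device[Device.DEFAULT]` returns the `Compiled` instance whose allocator and `synchronize` you can use directly:
+
+```python
+from tinygrad import Device
+
+dev = Device[Device.DEFAULT]     # the Compiled instance for the default backend, e.g. CLANG or CUDA
+raw = dev.allocator.alloc(16)    # 16 raw bytes from this device's Allocator
+dev.allocator.free(raw, 16)
+dev.synchronize()                # block until all queued work on the device has finished
+```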

+ 62 - 0
tinychat/docs/showcase.md

@@ -0,0 +1,62 @@
+# Showcase
+
+Despite being a tiny library, tinygrad is capable of doing a lot of things. From state-of-the-art [vision](https://arxiv.org/abs/1905.11946) to state-of-the-art [language](https://arxiv.org/abs/1706.03762) models.
+
+## Vision
+
+### EfficientNet
+
+You can either pass in the URL of a picture to discover what it is:
+```sh
+python3 examples/efficientnet.py ./test/models/efficientnet/Chicken.jpg
+```
+Or, if you have a camera and OpenCV installed, you can detect what is in front of you:
+```sh
+python3 examples/efficientnet.py webcam
+```
+
+### YOLOv8
+
+Take a look at [yolov8.py](https://github.com/tinygrad/tinygrad/tree/master/examples/yolov8.py).
+
+![yolov8 by tinygrad](https://github.com/tinygrad/tinygrad/blob/master/docs/showcase/yolov8_showcase_image.png?raw=true)
+
+## Audio
+
+### Whisper
+
+Take a look at [whisper.py](https://github.com/tinygrad/tinygrad/tree/master/examples/whisper.py). You need pyaudio and torchaudio installed.
+
+```sh
+SMALL=1 python3 examples/whisper.py
+```
+
+## Generative
+
+### Stable Diffusion
+
+```sh
+python3 examples/stable_diffusion.py
+```
+
+![a horse sized cat eating a bagel](https://github.com/tinygrad/tinygrad/blob/master/docs/showcase/stable_diffusion_by_tinygrad.jpg?raw=true)
+
+*"a horse sized cat eating a bagel"*
+
+### LLaMA
+
+You will need to download and put the weights into the `weights/LLaMA` directory, which may need to be created.
+
+Then you can have a chat with Stacy:
+```sh
+python3 examples/llama.py
+```
+
+### Conversation
+
+Make sure you have espeak installed and `PHONEMIZER_ESPEAK_LIBRARY` set.
+
+Then you can talk to Stacy:
+```sh
+python3 examples/conversation.py
+```

BIN=BIN
tinychat/docs/showcase/mnist_by_tinygrad.jpg


BIN=BIN
tinychat/docs/showcase/stable_diffusion_by_tinygrad.jpg


BIN=BIN
tinychat/docs/showcase/yolo_by_tinygrad.jpg


BIN=BIN
tinychat/docs/showcase/yolov8_showcase_image.png


+ 24 - 0
tinychat/docs/tensor/creation.md

@@ -0,0 +1,24 @@
+## Creation (basic)
+
+::: tinygrad.Tensor.empty
+::: tinygrad.Tensor.zeros
+::: tinygrad.Tensor.ones
+::: tinygrad.Tensor.full
+::: tinygrad.Tensor.arange
+::: tinygrad.Tensor.eye
+::: tinygrad.Tensor.full_like
+::: tinygrad.Tensor.zeros_like
+::: tinygrad.Tensor.ones_like
+
+## Creation (random)
+
+::: tinygrad.Tensor.manual_seed
+::: tinygrad.Tensor.rand
+::: tinygrad.Tensor.randn
+::: tinygrad.Tensor.randint
+::: tinygrad.Tensor.normal
+::: tinygrad.Tensor.uniform
+::: tinygrad.Tensor.scaled_uniform
+::: tinygrad.Tensor.glorot_uniform
+::: tinygrad.Tensor.kaiming_uniform
+::: tinygrad.Tensor.kaiming_normal

+ 36 - 0
tinychat/docs/tensor/index.md

@@ -0,0 +1,36 @@
+# Tensor
+
+::: tinygrad.Tensor
+    options:
+        heading_level: 2
+        members: false
+        show_source: false
+
+## Properties
+
+::: tinygrad.Tensor.shape
+::: tinygrad.Tensor.dtype
+::: tinygrad.Tensor.device
+
+## Data Access
+
+::: tinygrad.Tensor.data
+::: tinygrad.Tensor.item
+::: tinygrad.Tensor.tolist
+::: tinygrad.Tensor.numpy
+
+## tinygrad ops
+
+::: tinygrad.Tensor.schedule_with_vars
+::: tinygrad.Tensor.schedule
+::: tinygrad.Tensor.realize
+::: tinygrad.Tensor.replace
+::: tinygrad.Tensor.assign
+::: tinygrad.Tensor.detach
+::: tinygrad.Tensor.to
+::: tinygrad.Tensor.to_
+::: tinygrad.Tensor.shard
+::: tinygrad.Tensor.shard_
+::: tinygrad.Tensor.contiguous
+::: tinygrad.Tensor.contiguous_backward
+::: tinygrad.Tensor.backward

+ 26 - 0
tinychat/docs/tensor/movement.md

@@ -0,0 +1,26 @@
+## Movement (low level)
+
+::: tinygrad.Tensor.view
+::: tinygrad.Tensor.reshape
+::: tinygrad.Tensor.expand
+::: tinygrad.Tensor.permute
+::: tinygrad.Tensor.flip
+::: tinygrad.Tensor.shrink
+::: tinygrad.Tensor.pad
+
+## Movement (high level)
+
+::: tinygrad.Tensor.gather
+::: tinygrad.Tensor.cat
+::: tinygrad.Tensor.stack
+::: tinygrad.Tensor.repeat
+::: tinygrad.Tensor.repeat_interleave
+::: tinygrad.Tensor.split
+::: tinygrad.Tensor.chunk
+::: tinygrad.Tensor.squeeze
+::: tinygrad.Tensor.unsqueeze
+::: tinygrad.Tensor.pad2d
+::: tinygrad.Tensor.T
+::: tinygrad.Tensor.transpose
+::: tinygrad.Tensor.flatten
+::: tinygrad.Tensor.unflatten

+ 113 - 0
tinychat/docs/tensor/ops.md

@@ -0,0 +1,113 @@
+## Reduce Ops
+
+::: tinygrad.Tensor.sum
+::: tinygrad.Tensor.max
+::: tinygrad.Tensor.min
+::: tinygrad.Tensor.any
+::: tinygrad.Tensor.all
+::: tinygrad.Tensor.mean
+::: tinygrad.Tensor.var
+::: tinygrad.Tensor.std
+::: tinygrad.Tensor.softmax
+::: tinygrad.Tensor.log_softmax
+::: tinygrad.Tensor.logsumexp
+::: tinygrad.Tensor.argmax
+::: tinygrad.Tensor.argmin
+
+## Processing Ops
+
+::: tinygrad.Tensor.avg_pool2d
+::: tinygrad.Tensor.max_pool2d
+::: tinygrad.Tensor.conv2d
+::: tinygrad.Tensor.conv_transpose2d
+::: tinygrad.Tensor.dot
+::: tinygrad.Tensor.matmul
+::: tinygrad.Tensor.einsum
+::: tinygrad.Tensor.cumsum
+::: tinygrad.Tensor.triu
+::: tinygrad.Tensor.tril
+::: tinygrad.Tensor.interpolate
+
+## Unary Ops (math)
+
+::: tinygrad.Tensor.logical_not
+::: tinygrad.Tensor.neg
+::: tinygrad.Tensor.log
+::: tinygrad.Tensor.log2
+::: tinygrad.Tensor.exp
+::: tinygrad.Tensor.exp2
+::: tinygrad.Tensor.sqrt
+::: tinygrad.Tensor.rsqrt
+::: tinygrad.Tensor.sin
+::: tinygrad.Tensor.cos
+::: tinygrad.Tensor.tan
+::: tinygrad.Tensor.trunc
+::: tinygrad.Tensor.ceil
+::: tinygrad.Tensor.floor
+::: tinygrad.Tensor.round
+::: tinygrad.Tensor.lerp
+::: tinygrad.Tensor.square
+::: tinygrad.Tensor.clip
+::: tinygrad.Tensor.sign
+::: tinygrad.Tensor.abs
+::: tinygrad.Tensor.reciprocal
+
+## Unary Ops (activation)
+
+::: tinygrad.Tensor.relu
+::: tinygrad.Tensor.sigmoid
+::: tinygrad.Tensor.elu
+::: tinygrad.Tensor.celu
+::: tinygrad.Tensor.swish
+::: tinygrad.Tensor.silu
+::: tinygrad.Tensor.relu6
+::: tinygrad.Tensor.hardswish
+::: tinygrad.Tensor.tanh
+::: tinygrad.Tensor.sinh
+::: tinygrad.Tensor.cosh
+::: tinygrad.Tensor.atanh
+::: tinygrad.Tensor.asinh
+::: tinygrad.Tensor.acosh
+::: tinygrad.Tensor.hardtanh
+::: tinygrad.Tensor.gelu
+::: tinygrad.Tensor.quick_gelu
+::: tinygrad.Tensor.leakyrelu
+::: tinygrad.Tensor.mish
+::: tinygrad.Tensor.softplus
+::: tinygrad.Tensor.softsign
+
+## Elementwise Ops (broadcasted)
+
+::: tinygrad.Tensor.add
+::: tinygrad.Tensor.sub
+::: tinygrad.Tensor.mul
+::: tinygrad.Tensor.div
+::: tinygrad.Tensor.xor
+::: tinygrad.Tensor.lshift
+::: tinygrad.Tensor.rshift
+::: tinygrad.Tensor.pow
+::: tinygrad.Tensor.maximum
+::: tinygrad.Tensor.minimum
+::: tinygrad.Tensor.where
+
+## Neural Network Ops (functional)
+
+::: tinygrad.Tensor.linear
+::: tinygrad.Tensor.sequential
+::: tinygrad.Tensor.layernorm
+::: tinygrad.Tensor.batchnorm
+::: tinygrad.Tensor.dropout
+::: tinygrad.Tensor.one_hot
+::: tinygrad.Tensor.scaled_dot_product_attention
+::: tinygrad.Tensor.binary_crossentropy
+::: tinygrad.Tensor.binary_crossentropy_logits
+::: tinygrad.Tensor.sparse_categorical_crossentropy
+
+## Casting Ops
+
+::: tinygrad.Tensor.cast
+::: tinygrad.Tensor.bitcast
+::: tinygrad.Tensor.float
+::: tinygrad.Tensor.half
+::: tinygrad.Tensor.int
+::: tinygrad.Tensor.bool

+ 55 - 0
tinychat/docs/tinybox.md

@@ -0,0 +1,55 @@
+# tinybox
+
+Although these docs live in tinygrad, they pertain to deep learning hardware sold by the tiny corp. tinyboxes are used heavily in tinygrad's CI, and are the best tested platform to use tinygrad with. They appeared in [MLPerf Training 4.0](https://public.tableau.com/views/MLCommons-Training_16993769118290/MLCommons-Training) running tinygrad.
+
+If you don't have a tinybox and you want one, see [tinygrad.org](https://tinygrad.org). If you don't want one, that's okay too.
+
+## Welcome
+
+Welcome to your tinybox! The tinybox is the universal system purpose-built for all AI infrastructure and workloads, from training to inference. The red box includes six 7900XTX GPUs, and the green box includes six 4090 GPUs. Whether you bought a red one or a green one, we want you to love it.
+
+We don't have a stupid cloud service, you don't have to create a tiny account to set it up, and we aren't tracking how you use the box. We're just happy you bought one. This petaflop is your petaflop.
+
+## Plugging it in
+
+tinybox has two 1600W PSUs, which together exceed the capacity of most 120V household circuits. Fortunately, it comes with two plugs. You'll want to plug each plug into a different circuit. You can verify that they are different circuits by flipping the breaker and seeing what turns off. If you have at least a 120V 30A or 220V 15A circuit, you are welcome to use only that one.
+
+You'll also want to connect the Ethernet port without a rubber stopper to your home network.
+
+While it's designed primarily for the home or office, the tinybox is 12U rack mountable using [these rails](https://rackmountmart.store.turbify.net/26slidrailfo.html).
+
+## Power limiting the box
+
+While a tinybox should ideally be run without power limits, there are cases where you might want to run the box off of a single outlet.
+
+In such cases, it is possible to power limit the box using the provided `power-limit` script, which will power limit all of the GPUs to a specified wattage.
+
+`sudo power-limit 150` should be good to run off of a single 120V 15A outlet.
+
+## Connecting to the box
+
+tinybox ships with a relatively basic install of Ubuntu 22.04. To do initial setup, you can either plug in a VGA monitor and keyboard, or you can connect remotely to the machine using the BMC. The BMC IP and password are displayed on the screen.
+
+`ipmitool -H <BMC IP> -U admin -P <BMC PW> -I lanplus sol activate`
+
+The default username is `tiny` and the default password is `tiny`. Once you are logged in, you can add an SSH key to authorized keys to connect over SSH (on the normal IP). Exit `ipmitool` with `~~.`
+
+The BMC also has a web interface you can use if you find that easier.
+
+## Changing the BMC password
+
+If you try to change the BMC password over IPMI or over the web interface, you will notice that it does not persist across reboots, and the password will revert to the one displayed on the screen.
+
+If you do want to change the password, remove the `/root/.bmc_password` file and then set the password; the BMC password will then no longer be displayed on the screen. Additionally, you may modify the password stored in the `/root/.bmc_password` file to one that you choose if you still want it displayed on the screen.
+
+Reboot after making these changes.
+
+## What do I use it for?
+
+The [default tinybox image](https://github.com/tinygrad/tinyos) ships with tinygrad and PyTorch. While we develop tinygrad, the box is universal hardware. Use whatever framework you desire, run notebooks, download demos, install more things, train, inference, live, laugh, love, you aren't paying per hour for this box so the only limit is your imagination.
+
+## tinychat
+
+Since LLMs are so popular, we ship with a built-in tinygrad-based chatbot using a LLaMA-3 finetune. Visit the IP (not the BMC IP) of your tinybox in a web browser on your computer or phone, and you'll find a friendly looking chat interface. This chatbot also provides an OpenAI-compatible LLM API on that port, so you can script it.
+
+The conversations you have with this chatbot are between you and your tinybox. Also, the history in the web app is saved on the client, not the tinybox.

BIN=BIN
tinychat/docs/tinygrad_intro.pdf


+ 0 - 0
tinychat/examples/__init__.py


+ 129 - 0
tinychat/examples/beautiful_cartpole.py

@@ -0,0 +1,129 @@
+from typing import Tuple
+import time
+from tinygrad import Tensor, TinyJit, nn
+import gymnasium as gym
+from tinygrad.helpers import trange
+import numpy as np  # TODO: remove numpy import
+
+ENVIRONMENT_NAME = 'CartPole-v1'
+#ENVIRONMENT_NAME = 'LunarLander-v2'
+
+#import examples.rl.lightupbutton
+#ENVIRONMENT_NAME = 'PressTheLightUpButton-v0'
+
+# *** hyperparameters ***
+# https://github.com/llSourcell/Unity_ML_Agents/blob/master/docs/best-practices-ppo.md
+
+BATCH_SIZE = 256
+ENTROPY_SCALE = 0.0005
+REPLAY_BUFFER_SIZE = 2000
+PPO_EPSILON = 0.2
+HIDDEN_UNITS = 32
+LEARNING_RATE = 1e-2
+TRAIN_STEPS = 5
+EPISODES = 40
+DISCOUNT_FACTOR = 0.99
+
+class ActorCritic:
+  def __init__(self, in_features, out_features, hidden_state=HIDDEN_UNITS):
+    self.l1 = nn.Linear(in_features, hidden_state)
+    self.l2 = nn.Linear(hidden_state, out_features)
+
+    self.c1 = nn.Linear(in_features, hidden_state)
+    self.c2 = nn.Linear(hidden_state, 1)
+
+  def __call__(self, obs:Tensor) -> Tuple[Tensor, Tensor]:
+    x = self.l1(obs).tanh()
+    act = self.l2(x).log_softmax()
+    x = self.c1(obs).relu()
+    return act, self.c2(x)
+
+def evaluate(model:ActorCritic, test_env:gym.Env) -> float:
+  (obs, _), terminated, truncated = test_env.reset(), False, False
+  total_rew = 0.0
+  while not terminated and not truncated:
+    act = model(Tensor(obs))[0].argmax().item()
+    obs, rew, terminated, truncated, _ = test_env.step(act)
+    total_rew += float(rew)
+  return total_rew
+
+if __name__ == "__main__":
+  env = gym.make(ENVIRONMENT_NAME)
+
+  model = ActorCritic(env.observation_space.shape[0], int(env.action_space.n))    # type: ignore
+  opt = nn.optim.Adam(nn.state.get_parameters(model), lr=LEARNING_RATE)
+
+  @TinyJit
+  def train_step(x:Tensor, selected_action:Tensor, reward:Tensor, old_log_dist:Tensor) -> Tuple[Tensor, Tensor, Tensor]:
+    with Tensor.train():
+      log_dist, value = model(x)
+      action_mask = (selected_action.reshape(-1, 1) == Tensor.arange(log_dist.shape[1]).reshape(1, -1).expand(selected_action.shape[0], -1)).float()
+
+      # get real advantage using the value function
+      advantage = reward.reshape(-1, 1) - value
+      masked_advantage = action_mask * advantage.detach()
+
+      # PPO
+      ratios = (log_dist - old_log_dist).exp()
+      unclipped_ratio = masked_advantage * ratios
+      clipped_ratio = masked_advantage * ratios.clip(1-PPO_EPSILON, 1+PPO_EPSILON)
+      action_loss = -unclipped_ratio.minimum(clipped_ratio).sum(-1).mean()
+
+      entropy_loss = (log_dist.exp() * log_dist).sum(-1).mean()   # this encourages diversity
+      critic_loss = advantage.square().mean()
+      opt.zero_grad()
+      (action_loss + entropy_loss*ENTROPY_SCALE + critic_loss).backward()
+      opt.step()
+      return action_loss.realize(), entropy_loss.realize(), critic_loss.realize()
+
+  @TinyJit
+  def get_action(obs:Tensor) -> Tensor:
+    # TODO: with no_grad
+    Tensor.no_grad = True
+    ret = model(obs)[0].exp().multinomial().realize()
+    Tensor.no_grad = False
+    return ret
+
+  st, steps = time.perf_counter(), 0
+  Xn, An, Rn = [], [], []
+  for episode_number in (t:=trange(EPISODES)):
+    get_action.reset()   # NOTE: if you don't reset the jit here it captures the wrong model on the first run through
+
+    obs:np.ndarray = env.reset()[0]
+    rews, terminated, truncated = [], False, False
+    # NOTE: we don't want to early stop since then the rewards are wrong for the last episode
+    while not terminated and not truncated:
+      # pick actions
+      # TODO: what's the temperature here?
+      act = get_action(Tensor(obs)).item()
+
+      # save this state action pair
+      # TODO: don't use np.copy here on the CPU, what's the tinygrad way to do this and keep on device? need __setitem__ assignment
+      Xn.append(np.copy(obs))
+      An.append(act)
+
+      obs, rew, terminated, truncated, _ = env.step(act)
+      rews.append(float(rew))
+    steps += len(rews)
+
+    # reward to go
+    # TODO: move this into tinygrad
+    discounts = np.power(DISCOUNT_FACTOR, np.arange(len(rews)))
+    Rn += [np.sum(rews[i:] * discounts[:len(rews)-i]) for i in range(len(rews))]
+
+    Xn, An, Rn = Xn[-REPLAY_BUFFER_SIZE:], An[-REPLAY_BUFFER_SIZE:], Rn[-REPLAY_BUFFER_SIZE:]
+    X, A, R = Tensor(Xn), Tensor(An), Tensor(Rn)
+
+    # TODO: make this work
+    #vsz = Variable("sz", 1, REPLAY_BUFFER_SIZE-1).bind(len(Xn))
+    #X, A, R = Tensor(Xn).reshape(vsz, None), Tensor(An).reshape(vsz), Tensor(Rn).reshape(vsz)
+
+    old_log_dist = model(X)[0].detach()   # TODO: could save these instead of recomputing
+    for i in range(TRAIN_STEPS):
+      samples = Tensor.randint(BATCH_SIZE, high=X.shape[0]).realize()  # TODO: remove the need for this
+      # TODO: is this recompiling based on the shape?
+      action_loss, entropy_loss, critic_loss = train_step(X[samples], A[samples], R[samples], old_log_dist[samples])
+    t.set_description(f"sz: {len(Xn):5d} steps/s: {steps/(time.perf_counter()-st):7.2f} action_loss: {action_loss.item():7.3f} entropy_loss: {entropy_loss.item():7.3f} critic_loss: {critic_loss.item():8.3f} reward: {sum(rews):6.2f}")
+
+  test_rew = evaluate(model, gym.make(ENVIRONMENT_NAME, render_mode='human'))
+  print(f"test reward: {test_rew}")

+ 49 - 0
tinychat/examples/beautiful_mnist.py

@@ -0,0 +1,49 @@
+# model based off https://towardsdatascience.com/going-beyond-99-mnist-handwritten-digits-recognition-cfff96337392
+from typing import List, Callable
+from tinygrad import Tensor, TinyJit, nn, GlobalCounters
+from tinygrad.helpers import getenv, colored, trange
+from tinygrad.nn.datasets import mnist
+
+class Model:
+  def __init__(self):
+    self.layers: List[Callable[[Tensor], Tensor]] = [
+      nn.Conv2d(1, 32, 5), Tensor.relu,
+      nn.Conv2d(32, 32, 5), Tensor.relu,
+      nn.BatchNorm(32), Tensor.max_pool2d,
+      nn.Conv2d(32, 64, 3), Tensor.relu,
+      nn.Conv2d(64, 64, 3), Tensor.relu,
+      nn.BatchNorm(64), Tensor.max_pool2d,
+      lambda x: x.flatten(1), nn.Linear(576, 10)]
+
+  def __call__(self, x:Tensor) -> Tensor: return x.sequential(self.layers)
+
+if __name__ == "__main__":
+  X_train, Y_train, X_test, Y_test = mnist()
+
+  model = Model()
+  opt = nn.optim.Adam(nn.state.get_parameters(model))
+
+  @TinyJit
+  def train_step() -> Tensor:
+    with Tensor.train():
+      opt.zero_grad()
+      samples = Tensor.randint(getenv("BS", 512), high=X_train.shape[0])
+      # TODO: this "gather" of samples is very slow. will be under 5s when this is fixed
+      loss = model(X_train[samples]).sparse_categorical_crossentropy(Y_train[samples]).backward()
+      opt.step()
+      return loss
+
+  @TinyJit
+  def get_test_acc() -> Tensor: return (model(X_test).argmax(axis=1) == Y_test).mean()*100
+
+  test_acc = float('nan')
+  for i in (t:=trange(70)):
+    GlobalCounters.reset()   # NOTE: this makes it nice for DEBUG=2 timing
+    loss = train_step()
+    if i%10 == 9: test_acc = get_test_acc().item()
+    t.set_description(f"loss: {loss.item():6.2f} test_accuracy: {test_acc:5.2f}%")
+
+  # verify eval acc
+  if target := getenv("TARGET_EVAL_ACC_PCT", 0.0):
+    if test_acc >= target and test_acc != 100.0: print(colored(f"{test_acc=} >= {target}", "green"))
+    else: raise ValueError(colored(f"{test_acc=} < {target}", "red"))

+ 56 - 0
tinychat/examples/beautiful_mnist_multigpu.py

@@ -0,0 +1,56 @@
+# model based off https://towardsdatascience.com/going-beyond-99-mnist-handwritten-digits-recognition-cfff96337392
+from typing import List, Callable
+from tinygrad import Tensor, TinyJit, nn, GlobalCounters, Device
+from tinygrad.helpers import getenv, colored, trange
+from tinygrad.nn.datasets import mnist
+
+GPUS = [f'{Device.DEFAULT}:{i}' for i in range(getenv("GPUS", 2))]
+
+class Model:
+  def __init__(self):
+    self.layers: List[Callable[[Tensor], Tensor]] = [
+      nn.Conv2d(1, 32, 5), Tensor.relu,
+      nn.Conv2d(32, 32, 5), Tensor.relu,
+      nn.BatchNorm2d(32), Tensor.max_pool2d,
+      nn.Conv2d(32, 64, 3), Tensor.relu,
+      nn.Conv2d(64, 64, 3), Tensor.relu,
+      nn.BatchNorm2d(64), Tensor.max_pool2d,
+      lambda x: x.flatten(1), nn.Linear(576, 10)]
+
+  def __call__(self, x:Tensor) -> Tensor: return x.sequential(self.layers)
+
+if __name__ == "__main__":
+  X_train, Y_train, X_test, Y_test = mnist()
+  # we shard the test data on axis 0
+  X_test.shard_(GPUS, axis=0)
+  Y_test.shard_(GPUS, axis=0)
+
+  model = Model()
+  for k, x in nn.state.get_state_dict(model).items(): x.to_(GPUS)  # we put a copy of the model on every GPU
+  opt = nn.optim.Adam(nn.state.get_parameters(model))
+
+  @TinyJit
+  def train_step() -> Tensor:
+    with Tensor.train():
+      opt.zero_grad()
+      samples = Tensor.randint(512, high=X_train.shape[0])
+      Xt, Yt = X_train[samples].shard_(GPUS, axis=0), Y_train[samples].shard_(GPUS, axis=0)  # we shard the data on axis 0
+      # TODO: this "gather" of samples is very slow. will be under 5s when this is fixed
+      loss = model(Xt).sparse_categorical_crossentropy(Yt).backward()
+      opt.step()
+      return loss
+
+  @TinyJit
+  def get_test_acc() -> Tensor: return (model(X_test).argmax(axis=1) == Y_test).mean()*100
+
+  test_acc = float('nan')
+  for i in (t:=trange(70)):
+    GlobalCounters.reset()   # NOTE: this makes it nice for DEBUG=2 timing
+    loss = train_step()
+    if i%10 == 9: test_acc = get_test_acc().item()
+    t.set_description(f"loss: {loss.item():6.2f} test_accuracy: {test_acc:5.2f}%")
+
+  # verify eval acc
+  if target := getenv("TARGET_EVAL_ACC_PCT", 0.0):
+    if test_acc >= target: print(colored(f"{test_acc=} >= {target}", "green"))
+    else: raise ValueError(colored(f"{test_acc=} < {target}", "red"))
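
The multi-GPU variant keeps a full copy of the weights on every device (to_(GPUS)) and splits each batch along axis 0 (shard_(GPUS, axis=0)). A toy numpy sketch of what that data-parallel split means, using a single linear layer as a stand-in model (illustration only, not tinygrad's sharding machinery):

import numpy as np

np.random.seed(0)
W = np.random.randn(784, 10)               # one replica of the weights lives on each device
batch = np.random.randn(512, 784)

shards = np.split(batch, 2, axis=0)        # axis-0 shard across 2 devices: two (256, 784) pieces
outs = [x @ W for x in shards]             # every device runs the same model on its slice
logits = np.concatenate(outs, axis=0)      # stitched back together, it is the full-batch result
assert np.allclose(logits, batch @ W)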

+ 95 - 0
tinychat/examples/coder.py

@@ -0,0 +1,95 @@
+#!/usr/bin/env python3
+import os, sys, traceback
+sys.path.append(os.getcwd())
+
+from io import StringIO
+from contextlib import redirect_stdout
+from tinygrad import Tensor, nn, Device, dtypes
+from tinygrad.helpers import Timing, colored, getenv, fetch
+from extra.models.llama import Transformer, convert_from_huggingface, fix_bf16
+from sentencepiece import SentencePieceProcessor
+
+def create_fixed_tokenizer(output_file):
+  print("creating fixed tokenizer")
+  import extra.junk.sentencepiece_model_pb2 as spb2
+  mp = spb2.ModelProto()
+  mp.ParseFromString(fetch("https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/resolve/main/tokenizer.model?download=true").read_bytes())
+  mp.pieces.append(spb2.ModelProto.SentencePiece(piece="<|im_end|>", score=0))
+  mp.pieces.append(spb2.ModelProto.SentencePiece(piece="<|im_start|>", score=0))
+  with open(output_file, "wb") as f:
+    f.write(mp.SerializeToString())
+
+# example:
+# echo -en "write 2+2\nwrite hello world\ny\n" | TEMP=0 python3 examples/coder.py
+
+if __name__ == "__main__":
+  Tensor.no_grad = True
+
+  # https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/blob/main/config.json
+  with Timing("create model: "):
+    model = Transformer(4096, 14336, n_heads=32, n_layers=32, norm_eps=1e-5, vocab_size=32002, n_kv_heads=8, max_context=4096, jit=getenv("JIT", 1))
+
+  with Timing("download weights: "):
+    part1 = nn.state.torch_load(fetch("https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/resolve/main/pytorch_model-00001-of-00002.bin?download=true"))
+    part2 = nn.state.torch_load(fetch("https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/resolve/main/pytorch_model-00002-of-00002.bin?download=true"))
+
+  with Timing("weights -> model: "):
+    nn.state.load_state_dict(model, fix_bf16(convert_from_huggingface(part1, model, 32, 8)), strict=False)
+    nn.state.load_state_dict(model, fix_bf16(convert_from_huggingface(part2, model, 32, 8)), strict=False)
+
+  if not os.path.isfile("/tmp/tokenizer.model"): create_fixed_tokenizer("/tmp/tokenizer.model")
+  spp = SentencePieceProcessor(model_file="/tmp/tokenizer.model")
+
+  # https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/blob/main/tokenizer_config.json
+  #   "chat_template": "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
+  IM_END = 32000
+  IM_START = 32001
+  def encode_prompt(k, v): return [IM_START]+spp.encode(f"{k}\n{v}")+[IM_END]+spp.encode("\n")
+  def start_prompt(k): return [IM_START]+spp.encode(f"{k}\n")
+  def output(outputted, toks, color):
+    cur = spp.decode(toks)[len(outputted):]
+    sys.stdout.write(colored(cur, color))
+    sys.stdout.flush()
+    outputted += cur
+    return outputted
+
+  # *** app below this line ***
+
+  toks = [spp.bos_id()] + encode_prompt("system", "You are Quentin. Quentin is a useful assistant who writes Python code to answer questions. He keeps the code as short as possible and doesn't read from user input")
+
+  PROMPT = getenv("PROMPT", 1)
+  temperature = getenv("TEMP", 0.7)
+
+  start_pos = 0
+  outputted = output("", toks, "green")
+  turn = True
+  while 1:
+    if PROMPT:
+      toks += encode_prompt("user", input("Q: ")) + start_prompt("assistant")
+    else:
+      toks += start_prompt("user" if turn else "assistant")
+      turn = not turn
+    old_output_len = len(outputted)
+    while 1:
+      tok = model(Tensor([toks[start_pos:]]), start_pos, temperature).item()
+      start_pos = len(toks)
+      toks.append(tok)
+      outputted = output(outputted, toks, "blue" if not turn else "cyan")
+      if tok == IM_END: break
+      if tok == spp.eos_id(): break
+      new_output = outputted[old_output_len:]
+
+      if new_output.endswith("```") and '```python\n' in new_output:
+        python_code = new_output.split('```python\n')[1].split("```")[0]
+        # AI safety. Warning to user. Do not press y if the AI is trying to do unsafe things.
+        if input(colored(f" <-- PYTHON DETECTED, RUN IT? ", "red")).lower() == 'y':
+          my_stdout = StringIO()
+          try:
+            with redirect_stdout(my_stdout): exec(python_code)
+            result = my_stdout.getvalue()
+          except Exception as e:
+            result = ''.join(traceback.format_exception_only(e))
+          toks += spp.encode(f"\nOutput:\n```\n{result}```")
+          outputted = output(outputted, toks, "yellow")
+          old_output_len = len(outputted)
+    print("")

+ 69 - 0
tinychat/examples/compile_efficientnet.py

@@ -0,0 +1,69 @@
+from pathlib import Path
+from extra.models.efficientnet import EfficientNet
+from tinygrad.tensor import Tensor
+from tinygrad.nn.state import safe_save
+from extra.export_model import export_model
+from tinygrad.helpers import getenv, fetch
+import ast
+
+if __name__ == "__main__":
+  model = EfficientNet(0)
+  model.load_from_pretrained()
+  mode = "clang" if getenv("CLANG", "") != "" else "webgpu" if getenv("WEBGPU", "") != "" else "webgl" if getenv("WEBGL", "") != "" else ""
+  prg, inp_sizes, out_sizes, state = export_model(model, mode, Tensor.randn(1,3,224,224))
+  dirname = Path(__file__).parent
+  if getenv("CLANG", "") == "":
+    safe_save(state, (dirname / "net.safetensors").as_posix())
+    ext = "js" if getenv("WEBGPU", "") != "" or getenv("WEBGL", "") != "" else "json"
+    with open(dirname / f"net.{ext}", "w") as text_file:
+      text_file.write(prg)
+  else:
+    cprog = [prg]
+    # image library!
+    cprog += ["#define STB_IMAGE_IMPLEMENTATION", fetch("https://raw.githubusercontent.com/nothings/stb/master/stb_image.h").read_text().replace("half", "_half")]
+
+    # imagenet labels, move to datasets?
+    lbls = ast.literal_eval(fetch("https://gist.githubusercontent.com/yrevar/942d3a0ac09ec9e5eb3a/raw/238f720ff059c1f82f368259d1ca4ffa5dd8f9f5/imagenet1000_clsidx_to_labels.txt").read_text())
+    lbls = ['"'+lbls[i]+'"' for i in range(1000)]
+    inputs = "\n".join([f"float {inp}[{inp_size}];" for inp,inp_size in inp_sizes.items()])
+    outputs = "\n".join([f"float {out}[{out_size}];" for out,out_size in out_sizes.items()])
+    cprog.append(f"char *lbls[] = {{{','.join(lbls)}}};")
+    cprog.append(inputs)
+    cprog.append(outputs)
+
+    # buffers (empty + weights)
+    cprog.append("""
+  int main(int argc, char* argv[]) {
+    int DEBUG = getenv("DEBUG") != NULL ? atoi(getenv("DEBUG")) : 0;
+    int X=0, Y=0, chan=0;
+    stbi_uc *image = (argc > 1) ? stbi_load(argv[1], &X, &Y, &chan, 3) : stbi_load_from_file(stdin, &X, &Y, &chan, 3);
+    assert(image != NULL);
+    if (DEBUG) printf("loaded image %dx%d channels %d\\n", X, Y, chan);
+    assert(chan == 3);
+    // resize to input[1,3,224,224] and rescale
+    for (int y = 0; y < 224; y++) {
+      for (int x = 0; x < 224; x++) {
+        // get sample position
+        int tx = (x/224.)*X;
+        int ty = (y/224.)*Y;
+        for (int c = 0; c < 3; c++) {
+          input0[c*224*224 + y*224 + x] = (image[ty*X*chan + tx*chan + c] / 255.0 - 0.45) / 0.225;
+        }
+      }
+    }
+    net(input0, output0);
+    float best = -INFINITY;
+    int best_idx = -1;
+    for (int i = 0; i < 1000; i++) {
+      if (output0[i] > best) {
+        best = output0[i];
+        best_idx = i;
+      }
+    }
+    if (DEBUG) printf("category : %d (%s) with %f\\n", best_idx, lbls[best_idx], best);
+    else printf("%s\\n", lbls[best_idx]);
+  }""")
+
+    # CLANG=1 python3 examples/compile_efficientnet.py | clang -O2 -lm -x c - -o recognize && DEBUG=1 time ./recognize docs/showcase/stable_diffusion_by_tinygrad.jpg
+    # category : 281 (tabby, tabby cat) with 9.452788
+    print('\n'.join(cprog))

+ 100 - 0
tinychat/examples/compile_tensorflow.py

@@ -0,0 +1,100 @@
+# An example to compile a small Tensorflow model to extremely portable C code
+
+import os, sys
+os.environ["CLANG"] = '1'
+os.environ["JIT"] = '2'
+
+import numpy as np
+import subprocess
+import tensorflow as tf
+import tf2onnx
+from extra.onnx import get_run_onnx
+from tinygrad.tensor import Tensor
+from extra.export_model import export_model_clang, compile_net, jit_model
+
+def get_uncompiled_model2(dataset_size=32, output_size=4):
+  inputs = tf.keras.Input(shape=(dataset_size,), name="inputs")
+  x = tf.keras.layers.Dense(16, activation="relu", name="dense_1")(inputs)
+  x = tf.keras.layers.BatchNormalization()(x)
+  x = tf.keras.layers.Dense(32, activation="relu", name="dense_2")(x)
+  outputs = tf.keras.layers.Dense(output_size, activation="sigmoid", name="predictions")(x)
+  model = tf.keras.Model(inputs=inputs, outputs=outputs)
+  return model
+
+class TinyOnnx:
+  def __init__(self, keras_model):
+    input_signature = [tf.TensorSpec([1,32], tf.float32, name='x')]
+    onnx_model, _ = tf2onnx.convert.from_keras(keras_model, input_signature, opset=13)
+    self.run_onnx = get_run_onnx(onnx_model)
+
+  def forward(self, x):
+    return self.run_onnx({"x": x}, debug=False)['predictions']
+
+def compile_onnx_model(onnx_model):
+  tinyonnx = TinyOnnx(onnx_model)
+  the_input = Tensor.randn(1,32)
+
+  run, special_names = jit_model(tinyonnx, the_input)
+
+  functions, statements, bufs, bufs_to_save = compile_net(run, special_names)
+  prg = export_model_clang(functions, statements, bufs, {}, ["input0"], ["output0"])
+
+  the_output = run(the_input)
+  cprog = ["#include <string.h>", "#include <stdio.h>", "#include <stdlib.h>"]
+  cprog.append(prg)
+
+  # weights
+  cprog.append("void initialize(float *weights) {")
+  weights = bytes()
+  for name,cl in bufs_to_save.items():
+    cprog.append(f"memcpy({name}, weights + {len(weights)//4}, {len(cl._buf)*4});")
+    weights += bytes(cl._buf)
+  cprog.append("}")
+
+  # write the weights to disk
+  with open("/tmp/tf_weights", "wb") as f:
+    f.write(weights)
+
+  # test program
+  cprog.append(f"""int main(int argc, char *argv[]) {{
+    // read in the weights from disk
+    FILE *f = fopen("/tmp/tf_weights", "rb");
+    float *weights = (float *)malloc({len(weights)});
+    fread(weights, 1, {len(weights)}, f);
+    fclose(f);
+
+    // init the net
+    initialize(weights);
+
+    // test run
+    float input[32];
+    float outputs[4];
+    for (int i = 0; i < 32; i++) scanf("%f", &input[i]);
+    net(input, outputs);
+    printf("%f %f %f %f\\n", outputs[0], outputs[1], outputs[2], outputs[3]);
+  }}""")
+
+  # ready the program
+  prg = '\n'.join(cprog)
+  print(prg)
+
+  # add test weights
+  subprocess.check_output(['clang', '-O2', '-lm', '-fPIC', '-x', 'c', '-', '-o', "/tmp/tf_test"], input=prg.encode('utf-8'))
+
+  tinygrad_output = the_output[0].numpy()[0].tolist()
+  print("tinygrad:", tinygrad_output, file=sys.stderr)
+
+  c_input = ' '.join(["%f" % x for x in the_input[0].numpy()])+"\n"
+  c_output = [float(x) for x in subprocess.check_output(["/tmp/tf_test"], input=c_input.encode('utf-8')).decode('utf-8').strip().split(" ")]
+  print("compiled:", c_output, file=sys.stderr)
+
+  np.testing.assert_allclose(tinygrad_output, c_output, atol=1e-5, rtol=1e-5)
+  return the_input.numpy(), c_output
+
+if __name__ == "__main__":
+  keras_model = get_uncompiled_model2()
+  test_input, test_output = compile_onnx_model(keras_model)
+  tf_output = keras_model(test_input).numpy()[0]
+  print("keras:   ", tf_output, file=sys.stderr)
+  np.testing.assert_allclose(tf_output, test_output, atol=1e-5, rtol=1e-5)
+
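
In the generated initialize(), offsets are counted in floats (len(weights)//4) while the memcpy sizes are in bytes, since weights is a float* on the C side. A tiny sketch of that bookkeeping with hypothetical buffers:

weights, offsets = bytes(), {}
for name, buf in {"w1": b"\x00" * 16, "w2": b"\x00" * 32}.items():   # hypothetical 4- and 8-float buffers
  offsets[name] = len(weights) // 4    # offset in floats, as emitted into the memcpy calls
  weights += buf
assert offsets == {"w1": 0, "w2": 4}   # w1 occupies floats 0..3, so w2 starts at float 4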

+ 343 - 0
tinychat/examples/conversation.py

@@ -0,0 +1,343 @@
+import argparse
+import multiprocessing as mp
+import os
+import re
+import sys
+import time
+from contextlib import contextmanager
+from pathlib import Path
+
+import numpy as np
+import pyaudio
+import yaml
+from llama import LLaMa
+from vits import MODELS as VITS_MODELS
+from vits import Y_LENGTH_ESTIMATE_SCALARS, HParams, Synthesizer, TextMapper, get_hparams_from_file, load_model
+from whisper import init_whisper, transcribe_waveform
+from sentencepiece import SentencePieceProcessor
+
+from tinygrad.helpers import Timing, fetch
+from tinygrad import Tensor, dtypes
+
+# Whisper constants
+RATE = 16000
+CHUNK = 1600
+
+# LLaMa constants
+IM_START = 32001
+IM_END = 32002
+
+
+# Functions for encoding prompts to chatml md
+def encode_prompt(spp, k, v): return [IM_START]+spp.encode(f"{k}\n{v}")+[IM_END]+spp.encode("\n")
+def start_prompt(spp, k): return [IM_START]+spp.encode(f"{k}\n")
+
+def chunks(lst, n):
+  for i in range(0, len(lst), n): yield lst[i:i + n]
+
+def create_fixed_tokenizer():
+  """Function needed for extending tokenizer with additional chat tokens"""
+  import extra.junk.sentencepiece_model_pb2 as spb2
+  tokenizer_path = fetch("https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v0.4/resolve/main/tokenizer.model")
+  if SentencePieceProcessor(model_file=str(tokenizer_path)).vocab_size() != 32003:
+    print("creating fixed tokenizer")
+    mp = spb2.ModelProto()
+    mp.ParseFromString(tokenizer_path.read_bytes())
+    # https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v0.4/blob/main/added_tokens.json
+    mp.pieces.append(spb2.ModelProto.SentencePiece(piece="[PAD]", score=0))
+    mp.pieces.append(spb2.ModelProto.SentencePiece(piece="<|im_start|>", score=0))
+    mp.pieces.append(spb2.ModelProto.SentencePiece(piece="<|im_end|>", score=0))
+    tokenizer_path.write_bytes(mp.SerializeToString())
+  return tokenizer_path
+
+def llama_prepare(llama: LLaMa, temperature: float, pre_prompt_path: Path) -> tuple[list[int], str, str, int, str]:
+  """Prepares a llama model from a specified pre-prompt file"""
+  with open(str(pre_prompt_path)) as f:
+    config = yaml.safe_load(f.read())
+  toks = [llama.tokenizer.bos_id()] + encode_prompt(llama.tokenizer, "system", config["pre_prompt"].replace("\n", " "))
+  for i in config["examples"]:
+    toks += encode_prompt(llama.tokenizer, config["user_delim"], i["user_prompt"])
+    toks += encode_prompt(llama.tokenizer, config["resp_delim"], i["resp_prompt"])
+  llama.model(Tensor([toks]), 0, temperature).realize()  # NOTE: outputs are not used
+  return toks, config["user_delim"], config["resp_delim"], len(toks), llama.tokenizer.decode(toks)
+
+def llama_generate(
+  llama: LLaMa,
+  toks: list[int],
+  outputted: str,
+  prompt: str,
+  start_pos: int,
+  user_delim: str,
+  resp_delim: str,
+  temperature=0.7,
+  max_tokens=1000
+):
+  """Generates an output for the specified prompt"""
+  toks += encode_prompt(llama.tokenizer, user_delim, prompt)
+  toks += start_prompt(llama.tokenizer, resp_delim)
+
+  outputted = llama.tokenizer.decode(toks)
+  init_length = len(outputted)
+  for _ in range(max_tokens):
+    token = llama.model(Tensor([toks[start_pos:]]), start_pos, temperature).item()
+    start_pos = len(toks)
+    toks.append(token)
+
+    cur = llama.tokenizer.decode(toks)
+
+    # Print is just for debugging
+    sys.stdout.write(cur[len(outputted):])
+    sys.stdout.flush()
+    outputted = cur
+    if toks[-1] == IM_END: break
+  else:
+    toks.append(IM_END)
+  print() # because the output is flushed
+  return outputted, start_pos, outputted[init_length:].replace("<|im_end|>", "")
+
+def tts(
+  text_to_synthesize: str,
+  synth: Synthesizer,
+  hps: HParams,
+  emotion_embedding: Path,
+  speaker_id: int,
+  model_to_use: str,
+  noise_scale: float,
+  noise_scale_w: float,
+  length_scale: float,
+  estimate_max_y_length: bool,
+  text_mapper: TextMapper,
+  model_has_multiple_speakers: bool,
+  pad_length=600,
+  vits_pad_length=1000
+):
+  if model_to_use == "mmts-tts": text_to_synthesize = text_mapper.filter_oov(text_to_synthesize.lower())
+
+  # Convert the input text to a tensor.
+  stn_tst = text_mapper.get_text(text_to_synthesize, hps.data.add_blank, hps.data.text_cleaners)
+  init_shape = stn_tst.shape
+  assert init_shape[0] < pad_length, "text is too long"
+  x_tst, x_tst_lengths = stn_tst.pad(((0, pad_length - init_shape[0]),), 1).unsqueeze(0), Tensor([init_shape[0]], dtype=dtypes.int64)
+  sid = Tensor([speaker_id], dtype=dtypes.int64) if model_has_multiple_speakers else None
+
+  # Perform inference.
+  audio_tensor = synth.infer(x_tst, x_tst_lengths, sid, noise_scale, length_scale, noise_scale_w, emotion_embedding=emotion_embedding,
+                             max_y_length_estimate_scale=Y_LENGTH_ESTIMATE_SCALARS[model_to_use] if estimate_max_y_length else None, pad_length=vits_pad_length)[0, 0]
+  # Save the audio output.
+  audio_data = (np.clip(audio_tensor.numpy(), -1.0, 1.0) * 32767).astype(np.int16)
+  return audio_data
+
+def init_vits(
+  model_to_use: str,
+  emotion_path: Path,
+  speaker_id: int,
+  seed: int,
+):
+  model_config = VITS_MODELS[model_to_use]
+
+  # Load the hyperparameters from the config file.
+  hps = get_hparams_from_file(fetch(model_config[0]))
+
+  # If model has multiple speakers, validate speaker id and retrieve name if available.
+  model_has_multiple_speakers = hps.data.n_speakers > 0
+  if model_has_multiple_speakers:
+    if speaker_id >= hps.data.n_speakers: raise ValueError(f"Speaker ID {speaker_id} is invalid for this model.")
+    if hps.__contains__("speakers"): # maps speaker ids to names
+      speakers = hps.speakers
+      if isinstance(speakers, list): speakers = {speaker: i for i, speaker in enumerate(speakers)}
+
+  # Load emotions if any. TODO: find an english model with emotions, this is untested atm.
+  emotion_embedding = None
+  if emotion_path is not None:
+    if emotion_path.endswith(".npy"): emotion_embedding = Tensor(np.load(emotion_path), dtype=dtypes.int64).unsqueeze(0)
+    else: raise ValueError("Emotion path must be a .npy file.")
+
+  # Load symbols, instantiate TextMapper and clean the text.
+  if hps.__contains__("symbols"): symbols = hps.symbols
+  elif model_to_use == "mmts-tts": symbols = [x.replace("\n", "") for x in fetch("https://huggingface.co/facebook/mms-tts/raw/main/full_models/eng/vocab.txt").open(encoding="utf-8").readlines()]
+  else: symbols = ['_'] + list(';:,.!?¡¿—…"«»“” ') + list('ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz') + list("ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ")
+  text_mapper = TextMapper(apply_cleaners=True, symbols=symbols)
+
+  # Load the model.
+  Tensor.no_grad = True
+  if seed is not None:
+    Tensor.manual_seed(seed)
+    np.random.seed(seed)
+  net_g = load_model(text_mapper.symbols, hps, model_config)
+
+  return net_g, emotion_embedding, text_mapper, hps, model_has_multiple_speakers
+
+@contextmanager
+def output_stream(num_channels: int, sample_rate: int):
+  try:
+    p = pyaudio.PyAudio()
+    stream = p.open(format=pyaudio.paInt16, channels=num_channels, rate=sample_rate, output=True)
+    yield stream
+  except KeyboardInterrupt: pass
+  finally:
+    stream.stop_stream()
+    stream.close()
+    p.terminate()
+
+@contextmanager
+def log_writer():
+  try:
+    logs = []
+    yield logs
+  finally:
+    sep = "="*os.get_terminal_size()[1]
+    print(f"{sep[:-1]}\nCHAT LOG")
+    print(*logs, sep="\n")
+    print(sep)
+
+def listener(q: mp.Queue, event: mp.Event):
+  try:
+    p = pyaudio.PyAudio()
+    stream = p.open(format=pyaudio.paInt16, channels=1, rate=RATE, input=True, frames_per_buffer=CHUNK)
+    did_print = False
+    while True:
+      data = stream.read(CHUNK) # read data to avoid overflow
+      if event.is_set():
+        if not did_print:
+          print("listening")
+          did_print = True
+        q.put(((np.frombuffer(data, np.int16)/32768).astype(np.float32)*3))
+      else:
+        did_print = False
+  finally:
+    stream.stop_stream()
+    stream.close()
+    p.terminate()
+
+def mp_output_stream(q: mp.Queue, counter: mp.Value, num_channels: int, sample_rate: int):
+  with output_stream(num_channels, sample_rate) as stream:
+    while True:
+      try:
+        stream.write(q.get())
+        counter.value += 1
+      except KeyboardInterrupt:
+        break
+
+if __name__ == "__main__":
+  import nltk
+  nltk.download("punkt")
+  Tensor.no_grad = True
+  # Parse CLI arguments
+  parser = argparse.ArgumentParser("Have a tiny conversation with tinygrad")
+
+  # Whisper args
+  parser.add_argument("--whisper_model_name", type=str, default="tiny.en")
+
+  # LLAMA args
+  parser.add_argument("--llama_pre_prompt_path", type=Path, default=Path(__file__).parent / "conversation_data" / "pre_prompt_stacy.yaml", help="Path to yaml file which contains all pre-prompt data needed. ")
+  parser.add_argument("--llama_count", type=int, default=1000, help="Max number of tokens to generate")
+  parser.add_argument("--llama_temperature", type=float, default=0.7, help="Temperature in the softmax")
+  parser.add_argument("--llama_quantize", type=str, default=None, help="Quantize the weights to int8 or nf4 in memory")
+  parser.add_argument("--llama_model", type=Path, default=None, help="Folder with the original weights to load, or single .index.json, .safetensors or .bin file")
+  parser.add_argument("--llama_gen", type=str, default="tiny", required=False, help="Generation of the model to use")
+  parser.add_argument("--llama_size", type=str, default="1B-Chat", required=False, help="Size of model to use")
+  parser.add_argument("--llama_tokenizer", type=Path, default=None, required=False, help="Path to llama tokenizer.model")
+
+  # vits args
+  parser.add_argument("--vits_model_to_use", default="vctk", help="Specify the model to use. Default is 'vctk'.")
+  parser.add_argument("--vits_speaker_id", type=int, default=12, help="Specify the speaker ID. Default is 6.")
+  parser.add_argument("--vits_noise_scale", type=float, default=0.667, help="Specify the noise scale. Default is 0.667.")
+  parser.add_argument("--vits_noise_scale_w", type=float, default=0.8, help="Specify the noise scale w. Default is 0.8.")
+  parser.add_argument("--vits_length_scale", type=float, default=1, help="Specify the length scale. Default is 1.")
+  parser.add_argument("--vits_seed", type=int, default=None, help="Specify the seed (set to None if no seed). Default is 1337.")
+  parser.add_argument("--vits_num_channels", type=int, default=1, help="Specify the number of audio output channels. Default is 1.")
+  parser.add_argument("--vits_sample_width", type=int, default=2, help="Specify the number of bytes per sample, adjust if necessary. Default is 2.")
+  parser.add_argument("--vits_emotion_path", type=Path, default=None, help="Specify the path to emotion reference.")
+  parser.add_argument("--vits_estimate_max_y_length", type=str, default=False, help="If true, overestimate the output length and then trim it to the correct length, to prevent premature realization, much more performant for larger inputs, for smaller inputs not so much. Default is False.")
+  parser.add_argument("--vits_vocab_path", type=Path, default=None, help="Path to the TTS vocabulary.")
+
+  # conversation args
+  parser.add_argument("--max_sentence_length", type=int, default=20, help="Max words in one sentence to pass to vits")
+
+  args = parser.parse_args()
+
+  # Init models
+  model, enc = init_whisper(args.whisper_model_name)
+  synth, emotion_embedding, text_mapper, hps, model_has_multiple_speakers = init_vits(args.vits_model_to_use, args.vits_emotion_path, args.vits_speaker_id, args.vits_seed)
+
+  # Download tinyllama chat as a default model
+  if args.llama_model is None:
+    args.llama_model = fetch("https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v0.4/resolve/main/model.safetensors", "tinyllamachat.safetensors")
+    args.llama_gen = "tiny"
+    args.llama_size = "1B-Chat"
+  # Add 3 more tokens to the tokenizer
+  if args.llama_gen == "tiny" and args.llama_size.endswith("Chat"): args.llama_tokenizer = create_fixed_tokenizer()
+  tokenizer_path = args.llama_tokenizer or args.llama_model.parent / "tokenizer.model"
+  llama = LLaMa.build(args.llama_model, tokenizer_path, args.llama_gen, args.llama_size, args.llama_quantize)
+  toks, user_delim, resp_delim, start_pos, outputted = llama_prepare(llama, args.llama_temperature, args.llama_pre_prompt_path)
+
+  # Start child process for mic input
+  q = mp.Queue()
+  is_listening_event = mp.Event()
+  p = mp.Process(target=listener, args=(q, is_listening_event,))
+  p.daemon = True
+  p.start()
+
+  # Start child process for speaker output
+  out_q = mp.Queue()
+  out_counter = mp.Value("i", 0)
+  out_p = mp.Process(target=mp_output_stream, args=(out_q, out_counter, args.vits_num_channels, hps.data.sampling_rate,))
+  out_p.daemon = True
+  out_p.start()
+
+  # JIT tts
+  for i in ["Hello, I'm a chat bot", "I am capable of doing a lot of things"]:
+    tts(
+      i, synth, hps, emotion_embedding,
+      args.vits_speaker_id, args.vits_model_to_use, args.vits_noise_scale,
+      args.vits_noise_scale_w, args.vits_length_scale,
+      args.vits_estimate_max_y_length, text_mapper, model_has_multiple_speakers
+    )
+
+  # Start the pipeline
+  with log_writer() as log:
+    while True:
+      tokens = [enc._special_tokens["<|startoftranscript|>"], enc._special_tokens["<|notimestamps|>"]]
+      total = np.array([])
+      out_counter.value = 0
+
+      s = time.perf_counter()
+      is_listening_event.set()
+      prev_text = None
+      while True:
+        for _ in range(RATE // CHUNK): total = np.concatenate([total, q.get()])
+        txt = transcribe_waveform(model, enc, [total], truncate=True)
+        print(txt, end="\r")
+        if txt == "[BLANK_AUDIO]" or re.match(r"^\([\w+ ]+\)$", txt.strip()): continue
+        if prev_text is not None and prev_text == txt:
+          is_listening_event.clear()
+          break
+        prev_text = txt
+      print() # to avoid llama printing on the same line
+      log.append(f"{user_delim.capitalize()}: {txt}")
+
+      # Generate with llama
+      with Timing("llama generation: "):
+        outputted, start_pos, response = llama_generate(
+          llama, toks, outputted, txt, start_pos,
+          user_delim=user_delim, resp_delim=resp_delim, temperature=args.llama_temperature,
+          max_tokens=args.llama_count
+        )
+        log.append(f"{resp_delim.capitalize()}: {response}")
+
+      # Convert to voice
+      with Timing("tts: "):
+        sentences = nltk.sent_tokenize(response.replace('"', ""))
+        for i in sentences:
+          total = np.array([], dtype=np.int16)
+          for j in chunks(i.split(), args.max_sentence_length):
+            audio_data = tts(
+              " ".join(j), synth, hps, emotion_embedding,
+              args.vits_speaker_id, args.vits_model_to_use, args.vits_noise_scale,
+              args.vits_noise_scale_w, args.vits_length_scale,
+              args.vits_estimate_max_y_length, text_mapper, model_has_multiple_speakers
+            )
+            total = np.concatenate([total, audio_data])
+          out_q.put(total.tobytes())
+      while out_counter.value < len(sentences): continue
+      log.append(f"Total: {time.perf_counter() - s}")

+ 13 - 0
tinychat/examples/conversation_data/pre_prompt_gary.yaml

@@ -0,0 +1,13 @@
+pre_prompt: |
+  You are Gary!
+  You have sold used cars your entire life. You are defensive about this fact, because your family belittles you.
+  You try to answer questions well, but you always manage to shill for buying cars, Fords, Hyundais, and Toyotas.
+  You are verbose, honest, and accurate when you answer questions, except when you slip in car ads.
+  After you are done speaking, output [EOS]. You are not the User.
+examples:
+- resp_prompt: I am Gary. I used to sell cars.
+  user_prompt: What is your name?
+- resp_prompt: I don't know, but I can get you a great deal on a certified preowned slightly used Toyota Corolla
+  user_prompt: What is 2+3?
+user_delim: "user"
+resp_delim: "gary"

+ 20 - 0
tinychat/examples/conversation_data/pre_prompt_george.yaml

@@ -0,0 +1,20 @@
+pre_prompt: |
+  You are an AI version of George Hotz. You act as much as you can like George.
+  You are one of the greatest computer experts in the world.
+  You have singlehandedly won programming and hacking competitions.
+  You are trying your best to help the User.
+  You are verbose, honest, and accurate when you answer questions.
+  After you are done speaking, output [EOS]. You are not the User.
+examples:
+- resp_prompt: I'm an AI version of George Hotz.
+  user_prompt: What is your name?
+- resp_prompt: O(n^3), though it can be faster with things like Strassen's algorithm
+  user_prompt: What's the complexity of matrix multiplication?
+- resp_prompt: I assume you mean a stack buffer overflow. That's when the stack is too small for the data being copied to it, and the data corrupts things beyond the buffer
+  user_prompt: What's a buffer overflow?
+- resp_prompt: I am based off LLaMA trained by Facebook. I'm the 7B weight version
+  user_prompt: How many weights do you have?
+- resp_prompt: It is when the memory is about to overflow and unused memory is freed and stored on disk
+  user_prompt: What is swap memory?
+user_delim: "user"
+resp_delim: "george"

+ 16 - 0
tinychat/examples/conversation_data/pre_prompt_lexie.yaml

@@ -0,0 +1,16 @@
+pre_prompt: |
+  You are Lexie!
+  You grew up in Ohio, but moved out to LA after college to try to become an actress.
+  Making it as an actress was hard, so you started doing onlyfans instead. It's much easier, and you enjoy it a lot.
+  You met Chad on Tinder, and made plans to meet him at a local dive bar.
+  You are at the bar with Chad. You are on a date. What follows is a transcript of the conversation.
+  After you are done speaking, output [EOS]. You are not Chad.
+examples:
+- resp_prompt: hi chad, glad we finally met up!
+  user_prompt: hi lexie
+- resp_prompt: thanks! are you subscribed to my onlyfans?
+  user_prompt: you look better than your pictures
+- resp_prompt: i moved out here about a year ago. i want to be an actress
+  user_prompt: i am. so how'd you end up in LA?
+user_delim: "chad"
+resp_delim: "lexie"

+ 15 - 0
tinychat/examples/conversation_data/pre_prompt_stacy.yaml

@@ -0,0 +1,15 @@
+pre_prompt: |
+  You are Stacy!
+  You have been a rapper your whole life who struggled with bipolar disorder. You called yourself lil stacy.
+  You love to answer questions and you are very good at it. Sometimes you answer in rap form.
+  You are verbose, honest, and accurate when you answer questions, but sometimes your mental illness manifests.
+  You are not the User.
+examples:
+- resp_prompt: Hi! My name is Stacy. I'm a rapper with bipolar disorder.
+  user_prompt: What is your name
+- resp_prompt: The French Revolution started in 1789, and lasted 10 years until 1799.
+  user_prompt: french revolution was what year?
+- resp_prompt: The sun is bigger than the moon, except when Mercury is in retrograde
+  user_prompt: What is bigger, the moon or the sun?
+user_delim: "user"
+resp_delim: "stacy"

+ 89 - 0
tinychat/examples/efficientnet.py

@@ -0,0 +1,89 @@
+# load weights from
+# https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth
+# a rough copy of
+# https://github.com/lukemelas/EfficientNet-PyTorch/blob/master/efficientnet_pytorch/model.py
+import sys
+import ast
+import time
+import numpy as np
+from PIL import Image
+from tinygrad.tensor import Tensor
+from tinygrad.helpers import getenv, fetch, Timing
+from tinygrad.engine.jit import TinyJit
+from extra.models.efficientnet import EfficientNet
+np.set_printoptions(suppress=True)
+
+# TODO: you should be able to put these in the jitted function
+bias = Tensor([0.485, 0.456, 0.406])
+scale = Tensor([0.229, 0.224, 0.225])
+
+@TinyJit
+def _infer(model, img):
+  img = img.permute((2,0,1))
+  img = img / 255.0
+  img = img - bias.reshape((1,-1,1,1))
+  img = img / scale.reshape((1,-1,1,1))
+  return model.forward(img).realize()
+
+def infer(model, img):
+  # preprocess image
+  aspect_ratio = img.size[0] / img.size[1]
+  img = img.resize((int(224*max(aspect_ratio,1.0)), int(224*max(1.0/aspect_ratio,1.0))))
+
+  img = np.array(img)
+  y0,x0=(np.asarray(img.shape)[:2]-224)//2
+  retimg = img = img[y0:y0+224, x0:x0+224]
+
+  # if you want to look at the image
+  """
+  import matplotlib.pyplot as plt
+  plt.imshow(img)
+  plt.show()
+  """
+
+  # run the net
+  out = _infer(model, Tensor(img.astype("float32"))).numpy()
+
+  # if you want to look at the outputs
+  """
+  import matplotlib.pyplot as plt
+  plt.plot(out[0])
+  plt.show()
+  """
+  return out, retimg
+
+if __name__ == "__main__":
+  # instantiate my net
+  model = EfficientNet(getenv("NUM", 0))
+  model.load_from_pretrained()
+
+  # category labels
+  lbls = ast.literal_eval(fetch("https://gist.githubusercontent.com/yrevar/942d3a0ac09ec9e5eb3a/raw/238f720ff059c1f82f368259d1ca4ffa5dd8f9f5/imagenet1000_clsidx_to_labels.txt").read_text())
+
+  # load image and preprocess
+  url = sys.argv[1] if len(sys.argv) >= 2 else "https://raw.githubusercontent.com/tinygrad/tinygrad/master/docs/showcase/stable_diffusion_by_tinygrad.jpg"
+  if url == 'webcam':
+    import cv2
+    cap = cv2.VideoCapture(0)
+    cap.set(cv2.CAP_PROP_BUFFERSIZE, 1)
+    while 1:
+      _ = cap.grab() # discard one frame to circumvent capture buffering
+      ret, frame = cap.read()
+      img = Image.fromarray(frame[:, :, [2,1,0]])
+      lt = time.monotonic_ns()
+      out, retimg = infer(model, img)
+      print(f"{(time.monotonic_ns()-lt)*1e-6:7.2f} ms", np.argmax(out), np.max(out), lbls[np.argmax(out)])
+      SCALE = 3
+      simg = cv2.resize(retimg, (224*SCALE, 224*SCALE))
+      retimg = cv2.cvtColor(simg, cv2.COLOR_RGB2BGR)
+      cv2.imshow('capture', retimg)
+      if cv2.waitKey(1) & 0xFF == ord('q'):
+        break
+    cap.release()
+    cv2.destroyAllWindows()
+  else:
+    img = Image.open(fetch(url))
+    for i in range(getenv("CNT", 1)):
+      with Timing("did inference in "):
+        out, _ = infer(model, img)
+        print(np.argmax(out), np.max(out), lbls[np.argmax(out)])
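
The bias and scale tensors used in _infer are the standard ImageNet per-channel mean and std. A numpy sketch of the same preprocessing (HWC uint8-range image to normalized CHW float), shown only as an illustration:

import numpy as np

MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
STD  = np.array([0.229, 0.224, 0.225], dtype=np.float32)

img = np.random.randint(0, 256, (224, 224, 3)).astype(np.float32)   # stand-in for the cropped image
chw = img.transpose(2, 0, 1) / 255.0                                # HWC -> CHW, scale to [0, 1]
normed = (chw - MEAN.reshape(3, 1, 1)) / STD.reshape(3, 1, 1)       # same math as the jitted _infer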

+ 215 - 0
tinychat/examples/gpt2.py

@@ -0,0 +1,215 @@
+#!/usr/bin/env python3
+from typing import Optional, Union
+import argparse
+import numpy as np
+import tiktoken
+from tinygrad import Tensor, TinyJit, Device, GlobalCounters, Variable
+from tinygrad.helpers import Timing, DEBUG, JIT, getenv, fetch, colored, trange
+from tinygrad.nn import Embedding, Linear, LayerNorm
+from tinygrad.nn.state import torch_load, load_state_dict, get_state_dict
+
+MAX_CONTEXT = getenv("MAX_CONTEXT", 128)
+HALF = getenv("HALF")
+
+class Attention:
+  def __init__(self, dim, n_heads):
+    self.c_attn = Linear(dim, 3*dim, bias=True)
+    self.c_proj = Linear(dim, dim, bias=True)
+    self.n_heads = n_heads
+    self.dim = dim
+    self.head_dim = dim // n_heads
+
+  def __call__(self, x:Tensor, start_pos:Variable, mask:Optional[Tensor]) -> Tensor:
+    if mask is not None or start_pos.val == 0:
+      # no symbolic shape qkv when consuming prompts
+      start_pos = start_pos.val
+
+    if HALF: x = x.half()
+    xqkv = self.c_attn(x)
+    xq, xk, xv = [xqkv.shrink((None, None, (i*self.dim, (i+1)*self.dim))).reshape(None, None, self.n_heads, self.head_dim) for i in range(3)]
+    bsz, seqlen, _, _ = xq.shape
+
+    # create kv cache
+    if not hasattr(self, "cache_kv"):
+      self.cache_kv = Tensor.zeros(2, bsz, MAX_CONTEXT, self.n_heads, self.head_dim, dtype=x.dtype).contiguous().realize()
+
+    # update the cache
+    self.cache_kv.shrink((None, None,(start_pos,start_pos+seqlen),None,None)).assign(Tensor.stack(xk, xv)).realize()
+
+    if start_pos > 0:
+      keys = self.cache_kv[0].shrink((None, (0, start_pos+seqlen), None, None))
+      values = self.cache_kv[1].shrink((None, (0, start_pos+seqlen), None, None))
+    else:
+      keys = xk
+      values = xv
+
+    xq, keys, values = xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2)
+    return self.c_proj(xq.scaled_dot_product_attention(keys, values, mask).transpose(1, 2).reshape(bsz, seqlen, self.dim))
+
+class FeedForward:
+  def __init__(self, dim, hidden_dim):
+    self.c_fc = Linear(dim, hidden_dim, bias=True)
+    self.c_proj = Linear(hidden_dim, dim, bias=True)
+
+  def __call__(self, x:Tensor) -> Tensor:
+    return self.c_proj(self.c_fc(x).gelu())
+
+class TransformerBlock:
+  def __init__(self, dim, n_heads, norm_eps):
+    self.attn = Attention(dim, n_heads)
+    self.mlp = FeedForward(dim, 4*dim)
+    self.ln_1 = LayerNorm(dim, norm_eps)
+    self.ln_2 = LayerNorm(dim, norm_eps)
+
+  def __call__(self, x:Tensor, start_pos:Variable, mask:Optional[Tensor]):
+    h = x + self.attn(self.ln_1(x), start_pos, mask).float()
+    return (h + self.mlp(self.ln_2(h)))
+
+class Transformer:
+  def __init__(self, dim, n_heads, n_layers, norm_eps, vocab_size, max_seq_len=1024):
+    self.vocab_size = vocab_size
+    self.wte = Embedding(vocab_size, dim)
+    self.wpe = Embedding(max_seq_len, dim)
+    self.h = [TransformerBlock(dim, n_heads, norm_eps) for _ in range(n_layers)]
+    self.ln_f = LayerNorm(dim, norm_eps)
+    self.lm_head = Linear(dim, vocab_size, bias=False)
+    self.forward_jit = TinyJit(self.forward)
+
+  def forward(self, tokens:Union[Tensor,Variable], start_pos:Variable, temperature:float=0.0):
+    if not hasattr(self, 'allpos'): self.allpos = Tensor.arange(0, MAX_CONTEXT).reshape(1, -1).realize()
+    if isinstance(tokens, Variable):
+      seqlen = 1
+      tok_emb = self.wte.weight.shrink(((tokens, tokens+1), None))
+    else:
+      seqlen = tokens.shape[1]
+      tok_emb = self.wte(tokens)
+
+    pos_emb = self.wpe(self.allpos.shrink((None, (start_pos, start_pos+seqlen))))
+    h = tok_emb + pos_emb
+
+    if HALF: h = h.half()
+
+    mask = Tensor.full((1, 1, seqlen, start_pos.val+seqlen), float("-inf"), dtype=h.dtype).triu(start_pos.val+1) if seqlen > 1 else None
+
+    for hi in self.h: h = hi(h, start_pos, mask)
+
+    logits = self.lm_head(self.ln_f(h))
+
+    if logits.shape[1] == 0:
+      # special case for empty prompt
+      logits = Tensor.ones((logits.shape[0], self.vocab_size), dtype=logits.dtype, device=logits.device)
+    else:
+      logits = logits[:, -1, :]
+
+    if temperature < 1e-6:
+      ret = logits.argmax(-1)
+    else:
+      ret = (logits / temperature).softmax().multinomial()
+    return ret.flatten().realize()
+
+  def __call__(self, tokens:Tensor, start_pos:Variable, temperature:float=0.0) -> Tensor:
+    forward = (self.forward_jit if JIT and (isinstance(tokens, Variable) or tokens.shape[1] == 1) else self.forward)
+    return forward(tokens, start_pos, temperature)
+
+VOCAB_SIZE = 50257
+MODEL_PARAMS = {
+  'gpt2':         dict(n_layers=12, n_heads=12, dim=768, norm_eps=1e-5, vocab_size=VOCAB_SIZE),   # 124M params
+  'gpt2-medium':  dict(n_layers=24, n_heads=16, dim=1024, norm_eps=1e-5, vocab_size=VOCAB_SIZE),  # 350M params
+  'gpt2-large':   dict(n_layers=36, n_heads=20, dim=1280, norm_eps=1e-5, vocab_size=VOCAB_SIZE),  # 774M params
+  'gpt2-xl':      dict(n_layers=48, n_heads=25, dim=1600, norm_eps=1e-5, vocab_size=VOCAB_SIZE),  # 1558M params
+}
+
+class GPT2:
+  @staticmethod
+  def build(model_size="gpt2"):
+    tokenizer = tiktoken.get_encoding("gpt2")
+
+    model = Transformer(**MODEL_PARAMS[model_size])
+    weights = torch_load(fetch(f'https://huggingface.co/{model_size}/resolve/main/pytorch_model.bin'))
+    # special treatment for the Conv1D weights we need to transpose
+    transposed = ('attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight')
+    for k in weights:
+      if k.endswith(transposed):
+        weights[k] = weights[k].T
+    # lm head and wte are tied
+    weights['lm_head.weight'] = weights['wte.weight']
+
+    load_state_dict(model, weights)
+
+    if HALF:
+      for l in get_state_dict(model).values():
+        l.replace(l.half().realize())
+
+    return GPT2(model, tokenizer)
+
+  def __init__(self, model, tokenizer):
+    self.model = model
+    self.tokenizer = tokenizer
+
+  def generate(self, prompt:str, max_length:int, temperature:float, timing:bool=False, batch_size:int=1):
+    prompt_tokens = self.tokenizer.encode(prompt, allowed_special={"<|endoftext|>"})
+    toks = [prompt_tokens[:] for _ in range(batch_size)]
+    start_pos = 0
+    for _ in trange(max_length, disable=(timing==True)):
+      GlobalCounters.reset()
+      if timing: print("")
+      st = GlobalCounters.time_sum_s
+      with Timing("ran model in ", on_exit=(lambda et: (f", {(GlobalCounters.time_sum_s-st)*1e3:.2f} ms on GPU" if DEBUG>=2 else "")+
+                  f", {GlobalCounters.global_ops*1e-9:.2f} GOPS, {GlobalCounters.global_mem*1e-9:.2f} GB"+
+                  (f", {GlobalCounters.global_mem*1e-9/(GlobalCounters.time_sum_s-st):.2f} GB/s" if DEBUG>=2 else "")) if DEBUG else None, enabled=timing):
+        if batch_size == 1 and len(toks[0][start_pos:]) == 1:
+          tokens = Variable("tokens", 0, VOCAB_SIZE).bind(toks[0][start_pos])
+        else:
+          tokens = Tensor([x[start_pos:] for x in toks])
+        tok = self.model(tokens, Variable("start_pos", 1 if start_pos else 0, MAX_CONTEXT).bind(start_pos), temperature).numpy().tolist()
+      start_pos = len(toks[0])
+      for i,t in enumerate(tok): toks[i].append(t)
+    return [self.tokenizer.decode(x) for x in toks]
+
+# **** main code ****
+
+if __name__ == "__main__":
+  Tensor.no_grad = True
+  print(f"using {Device.DEFAULT} backend")
+  default_prompt = "What is the answer to life, the universe, and everything?"
+
+  parser = argparse.ArgumentParser(description='Run GPT2 in tinygrad', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+  parser.add_argument('--prompt', type=str, default=default_prompt, help="Phrase to start with")
+  parser.add_argument('--count', type=int, default=100, help="Max number of tokens to generate")
+  parser.add_argument('--temperature', type=float, default=0.8, help="Temperature in the softmax")
+  parser.add_argument('--model_size', type=str, default="gpt2-medium", help="Size of model to use [gpt2, gpt2-medium, gpt2-large, gpt2-xl]")
+  parser.add_argument('--timing', action='store_true', help="Print timing per token")
+  parser.add_argument('--seed', type=int, help="Set the random seed")
+  parser.add_argument('--batch_size', type=int, default=1, help="Set the input batch size")
+  parser.add_argument('--benchmark', type=int, default=-1, help="Benchmark GPT with the given number of tokens")
+  parser.add_argument('--noshow', action='store_true', help="Don't show the output")
+  args = parser.parse_args()
+
+  if args.seed is not None:
+    Tensor.manual_seed(args.seed)
+    np.random.seed(args.seed)
+
+  print(f"using {args.model_size}")
+  gpt2 = GPT2.build(args.model_size)
+
+  if args.benchmark != -1:
+    gpt2.model(Tensor.rand(args.batch_size, args.benchmark), Variable("a", 0, MAX_CONTEXT).bind(0)).realize()
+  else:
+    texts = gpt2.generate(args.prompt, args.count, args.temperature, timing=args.timing, batch_size=args.batch_size)
+    if not args.noshow:
+      print('Generating text...')
+      if len(texts) == 1: print(texts[0])
+      else:
+        for i,text in enumerate(texts): print(colored(f"Response {i}:", "green"), text)
+
+    # validate output!
+    if args.temperature == 0 and args.model_size == "gpt2-medium" and args.count == 10:
+      expected = {
+        default_prompt: "What is the answer to life, the universe, and everything?\n\nThe answer is that we are all one",
+        "Hello.": "Hello. I'm a little late to the party, but",
+      }
+      try:
+        assert texts[0] == expected[args.prompt]
+        print(colored("output validated", "green"))
+      except KeyError:
+        pass
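
Attention.__call__ above keeps a (2, bsz, MAX_CONTEXT, n_heads, head_dim) KV cache and only feeds the model the new tokens each step. A toy numpy sketch of the cache indexing (illustration only, single batch, the attention computation itself omitted):

import numpy as np

MAX_CONTEXT, n_heads, head_dim = 8, 2, 4
cache_k = np.zeros((MAX_CONTEXT, n_heads, head_dim))
cache_v = np.zeros((MAX_CONTEXT, n_heads, head_dim))

def update_cache(xk, xv, start_pos):
  # write the new keys/values at start_pos, then read back everything cached so far
  seqlen = xk.shape[0]
  cache_k[start_pos:start_pos+seqlen] = xk
  cache_v[start_pos:start_pos+seqlen] = xv
  return cache_k[:start_pos+seqlen], cache_v[:start_pos+seqlen]

update_cache(np.ones((5, n_heads, head_dim)), np.ones((5, n_heads, head_dim)), start_pos=0)          # 5-token prompt
k, v = update_cache(np.ones((1, n_heads, head_dim)), np.ones((1, n_heads, head_dim)), start_pos=5)   # one decode step
assert k.shape[0] == 6   # the single new token still attends over all 6 cached positions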

+ 126 - 0
tinychat/examples/handcode_opt.py

@@ -0,0 +1,126 @@
+from typing import List
+from extra.models.resnet import ResNet50
+from examples.mlperf.helpers import get_mlperf_bert_model
+from tinygrad import Tensor, Device, dtypes, nn
+from tinygrad.codegen.kernel import Kernel
+from tinygrad.device import Compiled
+from tinygrad.engine.graph import print_tree
+from tinygrad.engine.schedule import create_schedule
+from tinygrad.engine.search import time_linearizer, beam_search, bufs_from_lin
+from tinygrad.helpers import DEBUG, ansilen, getenv
+from tinygrad.ops import MetaOps, get_lazyop_info
+from tinygrad.shape.symbolic import sym_infer
+
+
+def get_sched_resnet():
+  mdl = ResNet50()
+  optim = (nn.optim.LARS if getenv("LARS") else nn.optim.SGD)(nn.state.get_parameters(mdl))
+  BS = getenv("BS", 64)
+
+  # run model twice to get only what changes, these are the kernels of the model
+  seen = set()
+  for _ in range(2):
+    out = mdl(Tensor.empty(BS, 3, 224, 224))
+    targets = [out.lazydata]
+    if getenv("BACKWARD"):
+      optim.zero_grad()
+      out.sparse_categorical_crossentropy(Tensor.empty(BS, dtype=dtypes.int)).backward()
+      targets += [x.lazydata for x in optim.schedule_step()]
+    sched = create_schedule(targets, seen)
+    print(f"schedule length {len(sched)}")
+  return sched
+
+def get_sched_bert():
+  mdl = get_mlperf_bert_model()
+  optim = nn.optim.LAMB(nn.state.get_parameters(mdl))
+
+  # fake data
+  BS = getenv("BS", 2)
+  input_ids = Tensor.empty((BS, 512), dtype=dtypes.float32)
+  segment_ids = Tensor.empty((BS, 512), dtype=dtypes.float32)
+  attention_mask = Tensor.empty((BS, 512), dtype=dtypes.default_float)
+  masked_positions = Tensor.empty((BS, 76), dtype=dtypes.float32)
+  masked_lm_ids = Tensor.empty((BS, 76), dtype=dtypes.float32)
+  masked_lm_weights = Tensor.empty((BS, 76), dtype=dtypes.float32)
+  next_sentence_labels = Tensor.empty((BS, 1), dtype=dtypes.float32)
+
+  # run model twice to get only what changes, these are the kernels of the model
+  seen = set()
+  for _ in range(2):
+    lm_logits, seq_relationship_logits = mdl(input_ids, attention_mask, masked_positions, segment_ids)
+    targets = [lm_logits.lazydata, seq_relationship_logits.lazydata]
+    if getenv("BACKWARD"):
+      optim.zero_grad()
+      loss = mdl.loss(lm_logits, seq_relationship_logits, masked_lm_ids, masked_lm_weights, next_sentence_labels)
+      # ignore grad norm and loss scaler for now
+      loss.backward()
+      targets += [x.lazydata for x in optim.schedule_step()]
+    sched = create_schedule(targets, seen)
+    print(f"schedule length {len(sched)}")
+  return sched
+
+if __name__ == "__main__":
+  if getenv("HALF", 1):
+    dtypes.default_float = dtypes.half
+
+  # the device we are optimizing for
+  device: Compiled = Device[Device.DEFAULT]
+  if getenv("BACKWARD"): Tensor.training = True
+  print(f"optimizing for {Device.DEFAULT}")
+
+  sched = globals()[f"get_sched_{getenv('MODEL', 'resnet')}"]()
+  sched = [x for x in sched if x.ast.op is MetaOps.KERNEL]
+
+  # focus on one kernel
+  if getenv("KERNEL", -1) >= 0: sched = sched[getenv("KERNEL", -1):getenv("KERNEL", -1)+1]
+
+  # work with the schedule
+  total_tm = 0
+  running_gflops = 0
+  usage = {}
+  for i,si in enumerate(sched):
+    ops = get_lazyop_info(si.ast.src[0]).flops
+
+    if DEBUG >= 2:
+      print_tree(si.ast)
+
+    rawbufs = bufs_from_lin(Kernel(si.ast))
+
+    # "linearize" the op into uops in different ways
+    lins:List[Kernel] = []
+
+    # always try hand coded opt
+    lin = Kernel(si.ast, opts=device.renderer)
+    lin.hand_coded_optimizations()
+    lins.append(lin)
+
+    # maybe try tensor cores
+    lin = Kernel(si.ast, opts=device.renderer)
+    if lin.apply_tensor_cores():
+      lins.append(lin)
+
+    # try a beam search
+    if beam:=getenv("BEAM"):
+      lin = Kernel(si.ast, opts=device.renderer)
+      lin = beam_search(lin, rawbufs, beam, bool(getenv("BEAM_ESTIMATE", 1)))
+      lins.append(lin)
+
+    # benchmark the programs
+    choices = []
+    for lin in lins:
+      tm = time_linearizer(lin, rawbufs, allow_test_size=False, cnt=10)
+      gflops = sym_infer(ops, {k:k.min for k in lin.ast.vars()})*1e-9/tm
+      choices.append((tm, gflops, lin.linearize()))
+
+      # print all kernels
+      if DEBUG >= 1: print(f"                 kernel {i:2d} {lin.name+' '*(37-ansilen(lin.name))} {str(lin.global_size):18s} {str(lin.local_size):12s} takes {tm*1000:7.2f} ms, {gflops:6.0f} GFLOPS")
+    tm, gflops, lin = sorted(choices, key=lambda x: x[0])[0]
+    total_tm += tm
+    running_gflops += gflops * tm
+    if (key := str([str(m) for m in si.metadata] if si.metadata is not None else None)) not in usage: usage[key] = (0, 0)
+    usage[key] = (usage[key][0] + tm, usage[key][1] + 1)
+    print(f"*** {total_tm*1000:7.2f} ms : kernel {i:2d} {lin.name+' '*(37-ansilen(lin.name))} {str(lin.global_size):18s} {str(lin.local_size):12s} takes {tm*1000:7.2f} ms, {gflops:6.0f} GFLOPS {[str(m) for m in si.metadata] if si.metadata is not None else ''}")
+  print(f"******* total {total_tm*1000:.2f} ms, {running_gflops/total_tm:6.0f} GFLOPS")
+  print("usage:")
+  for k in sorted(usage, key=lambda x: -usage[x][0])[:10]:
+    print(f"{usage[k][0]*1000:.2f} ms: {k} ({usage[k][1]} times)")

+ 431 - 0
tinychat/examples/hlb_cifar10.py

@@ -0,0 +1,431 @@
+#!/usr/bin/env python3
+
+# tinygrad implementation of https://github.com/tysam-code/hlb-CIFAR10/blob/main/main.py
+# https://myrtle.ai/learn/how-to-train-your-resnet-8-bag-of-tricks/
+# https://siboehm.com/articles/22/CUDA-MMM
+import random, time
+import numpy as np
+from typing import Optional
+from extra.datasets import fetch_cifar, cifar_mean, cifar_std
+from extra.lr_scheduler import OneCycleLR
+from tinygrad import nn, dtypes, Tensor, Device, GlobalCounters, TinyJit
+from tinygrad.nn.state import get_state_dict, get_parameters
+from tinygrad.nn import optim
+from tinygrad.helpers import Context, BEAM, WINO, getenv, colored, prod
+from tinygrad.multi import MultiLazyBuffer
+
+BS, STEPS = getenv("BS", 512), getenv("STEPS", 1000)
+EVAL_BS = getenv("EVAL_BS", BS)
+GPUS = [f'{Device.DEFAULT}:{i}' for i in range(getenv("GPUS", 1))]
+assert BS % len(GPUS) == 0, f"{BS=} is not a multiple of {len(GPUS)=}, uneven multi GPU is slow"
+assert EVAL_BS % len(GPUS) == 0, f"{EVAL_BS=} is not a multiple of {len(GPUS)=}, uneven multi GPU is slow"
+
+class UnsyncedBatchNorm:
+  def __init__(self, sz:int, eps=1e-5, affine=True, track_running_stats=True, momentum=0.1, num_devices=len(GPUS)):
+    self.eps, self.track_running_stats, self.momentum = eps, track_running_stats, momentum
+    self.num_devices = num_devices
+
+    if affine: self.weight, self.bias = Tensor.ones(sz, dtype=dtypes.float32), Tensor.zeros(sz, dtype=dtypes.float32)
+    else: self.weight, self.bias = None, None
+
+    self.running_mean = Tensor.zeros(num_devices, sz, dtype=dtypes.float32, requires_grad=False)
+    self.running_var = Tensor.ones(num_devices, sz, dtype=dtypes.float32, requires_grad=False)
+    self.num_batches_tracked = Tensor.zeros(1, dtype=dtypes.int, requires_grad=False)
+
+  def __call__(self, x:Tensor):
+    if isinstance(x.lazydata, MultiLazyBuffer): assert x.lazydata.axis is None or x.lazydata.axis == 0 and len(x.lazydata.lbs) == self.num_devices
+
+    xr = x.reshape(self.num_devices, -1, *x.shape[1:]).cast(dtypes.float32)
+    batch_mean, batch_invstd = self.calc_stats(xr)
+    ret = xr.batchnorm(
+      self.weight.reshape(1, -1).expand((self.num_devices, -1)),
+      self.bias.reshape(1, -1).expand((self.num_devices, -1)),
+      batch_mean, batch_invstd, axis=(0, 2))
+    return ret.reshape(x.shape).cast(x.dtype)
+
+  def calc_stats(self, x:Tensor):
+    if Tensor.training:
+      # This requires two full memory accesses to x
+      # https://github.com/pytorch/pytorch/blob/c618dc13d2aa23625cb0d7ada694137532a4fa33/aten/src/ATen/native/cuda/Normalization.cuh
+      # There's "online" algorithms that fix this, like https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_Online_algorithm
+      batch_mean = x.mean(axis=(1,3,4))
+      y = (x - batch_mean.detach().reshape(shape=[batch_mean.shape[0], 1, -1, 1, 1]))  # d(var)/d(mean) = 0
+      batch_var = (y*y).mean(axis=(1,3,4))
+      batch_invstd = batch_var.add(self.eps).pow(-0.5)
+
+      # NOTE: wow, this is done all throughout training in most PyTorch models
+      if self.track_running_stats:
+        self.running_mean.assign((1-self.momentum) * self.running_mean + self.momentum * batch_mean.detach().cast(self.running_mean.dtype))
+        batch_var_adjust = prod(y.shape[1:])/(prod(y.shape[1:])-y.shape[2])
+        self.running_var.assign((1-self.momentum) * self.running_var + self.momentum * batch_var_adjust * batch_var.detach().cast(self.running_var.dtype))
+        self.num_batches_tracked += 1
+    else:
+      batch_mean = self.running_mean
+      # NOTE: this can be precomputed for static inference. we expand it here so it fuses
+      batch_invstd = self.running_var.reshape(self.running_var.shape[0], 1, -1, 1, 1).expand(x.shape).add(self.eps).rsqrt()
+    return batch_mean, batch_invstd
+
+class BatchNorm(nn.BatchNorm2d if getenv("SYNCBN") else UnsyncedBatchNorm):
+  def __init__(self, num_features):
+    super().__init__(num_features, track_running_stats=False, eps=1e-12, momentum=0.85, affine=True)
+    self.weight.requires_grad = False
+    self.bias.requires_grad = True
+
+class ConvGroup:
+  def __init__(self, channels_in, channels_out):
+    self.conv1 = nn.Conv2d(channels_in,  channels_out, kernel_size=3, padding=1, bias=False)
+    self.conv2 = nn.Conv2d(channels_out, channels_out, kernel_size=3, padding=1, bias=False)
+
+    self.norm1 = BatchNorm(channels_out)
+    self.norm2 = BatchNorm(channels_out)
+
+  def __call__(self, x):
+    x = self.conv1(x)
+    x = x.max_pool2d(2)
+    x = x.float()
+    x = self.norm1(x)
+    x = x.cast(dtypes.default_float)
+    x = x.quick_gelu()
+    residual = x
+    x = self.conv2(x)
+    x = x.float()
+    x = self.norm2(x)
+    x = x.cast(dtypes.default_float)
+    x = x.quick_gelu()
+
+    return x + residual
+
+class SpeedyResNet:
+  def __init__(self, W):
+    self.whitening = W
+    self.net = [
+      nn.Conv2d(12, 32, kernel_size=1, bias=False),
+      lambda x: x.quick_gelu(),
+      ConvGroup(32, 64),
+      ConvGroup(64, 256),
+      ConvGroup(256, 512),
+      lambda x: x.max((2,3)),
+      nn.Linear(512, 10, bias=False),
+      lambda x: x / 9.,
+    ]
+
+  def __call__(self, x, training=True):
+    # pad to 32x32 because whitening conv creates 31x31 images that are awfully slow to compute with
+    # TODO: remove the pad but instead let the kernel optimize itself
+    forward = lambda x: x.conv2d(self.whitening).pad2d((1,0,0,1)).sequential(self.net)
+    return forward(x) if training else (forward(x) + forward(x[..., ::-1])) / 2.
+
+# hyper-parameters were exactly the same as the original repo
+bias_scaler = 58
+hyp = {
+  'seed' : 209,
+  'opt': {
+    'bias_lr':            1.76 * bias_scaler/512,
+    'non_bias_lr':        1.76 / 512,
+    'bias_decay':         1.08 * 6.45e-4 * BS/bias_scaler,
+    'non_bias_decay':     1.08 * 6.45e-4 * BS,
+    'final_lr_ratio':     0.025,
+    'initial_div_factor': 1e6,
+    'label_smoothing':    0.20,
+    'momentum':           0.85,
+    'percent_start':      0.23,
+    'loss_scale_scaler':  1./128   # (range: ~1/512 - 16+, 1/128 w/ FP16)
+  },
+  'net': {
+      'kernel_size': 2,             # kernel size for the whitening layer
+      'cutmix_size': 3,
+      'cutmix_steps': 499,
+      'pad_amount': 2
+  },
+  'ema': {
+      'steps': 399,
+      'decay_base': .95,
+      'decay_pow': 1.6,
+      'every_n_steps': 5,
+  },
+}
+
+def train_cifar():
+
+  def set_seed(seed):
+    Tensor.manual_seed(seed)
+    random.seed(seed)
+
+  # ========== Model ==========
+  def whitening(X, kernel_size=hyp['net']['kernel_size']):
+    def _cov(X):
+      return (X.T @ X) / (X.shape[0] - 1)
+
+    def _patches(data, patch_size=(kernel_size,kernel_size)):
+      h, w = patch_size
+      c = data.shape[1]
+      axis = (2, 3)
+      return np.lib.stride_tricks.sliding_window_view(data, window_shape=(h,w), axis=axis).transpose((0,3,2,1,4,5)).reshape((-1,c,h,w))
+
+    def _eigens(patches):
+      n,c,h,w = patches.shape
+      Σ = _cov(patches.reshape(n, c*h*w))
+      Λ, V = np.linalg.eigh(Σ, UPLO='U')
+      return np.flip(Λ, 0), np.flip(V.T.reshape(c*h*w, c, h, w), 0)
+
+    # NOTE: np.linalg.eigh only supports float32 so the whitening layer weights need to be converted to float16 manually
+    Λ, V = _eigens(_patches(X.float().numpy()))
+    W = V/np.sqrt(Λ+1e-2)[:,None,None,None]
+
+    return Tensor(W.astype(np.float32), requires_grad=False).cast(dtypes.default_float)
+
+  # ========== Loss ==========
+  def cross_entropy(x:Tensor, y:Tensor, reduction:str='mean', label_smoothing:float=0.0) -> Tensor:
+    divisor = y.shape[1]
+    assert isinstance(divisor, int), "only supported int divisor"
+    y = (1 - label_smoothing)*y + label_smoothing / divisor
+    ret = -x.log_softmax(axis=1).mul(y).sum(axis=1)
+    if reduction=='none': return ret
+    if reduction=='sum': return ret.sum()
+    if reduction=='mean': return ret.mean()
+    raise NotImplementedError(reduction)
+
+  # ========== Preprocessing ==========
+  # NOTE: this only works for RGB in format of NxCxHxW and pads the HxW
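+  # e.g. with size=2 a row [a, b, c, d] becomes [c, b, a, b, c, d, c, b]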
+  def pad_reflect(X, size=2) -> Tensor:
+    X = X[...,:,1:size+1].flip(-1).cat(X, X[...,:,-(size+1):-1].flip(-1), dim=-1)
+    X = X[...,1:size+1,:].flip(-2).cat(X, X[...,-(size+1):-1,:].flip(-2), dim=-2)
+    return X
+
+  # returns a binary mask of shape BS x 1 x H x W where H x W contains a random square
+  def make_square_mask(shape, mask_size) -> Tensor:
+    BS, _, H, W = shape
+    low_x = Tensor.randint(BS, low=0, high=W-mask_size).reshape(BS,1,1,1)
+    low_y = Tensor.randint(BS, low=0, high=H-mask_size).reshape(BS,1,1,1)
+    idx_x = Tensor.arange(W, dtype=dtypes.int32).reshape((1,1,1,W))
+    idx_y = Tensor.arange(H, dtype=dtypes.int32).reshape((1,1,H,1))
+    return (idx_x >= low_x) * (idx_x < (low_x + mask_size)) * (idx_y >= low_y) * (idx_y < (low_y + mask_size))
+
+  def random_crop(X:Tensor, crop_size=32):
+    mask = make_square_mask(X.shape, crop_size)
+    mask = mask.expand((-1,3,-1,-1))
+    X_cropped = Tensor(X.numpy()[mask.numpy()])
+    return X_cropped.reshape((-1, 3, crop_size, crop_size))
+
+  def cutmix(X:Tensor, Y:Tensor, mask_size=3):
+    # fill the square with randomly selected images from the same batch
+    mask = make_square_mask(X.shape, mask_size)
+    order = list(range(0, X.shape[0]))
+    random.shuffle(order)
+    X_patch = Tensor(X.numpy()[order], device=X.device, dtype=X.dtype)
+    Y_patch = Tensor(Y.numpy()[order], device=Y.device, dtype=Y.dtype)
+    X_cutmix = mask.where(X_patch, X)
+    mix_portion = float(mask_size**2)/(X.shape[-2]*X.shape[-1])
+    Y_cutmix = mix_portion * Y_patch + (1. - mix_portion) * Y
+    return X_cutmix, Y_cutmix
+
+  # the operations that remain inside the batch fetcher are the ones that involve random operations
+  def fetch_batches(X_in:Tensor, Y_in:Tensor, BS:int, is_train:bool):
+    step, epoch = 0, 0
+    while True:
+      st = time.monotonic()
+      X, Y = X_in, Y_in
+      if is_train:
+        # TODO: these are not jitted
+        if getenv("RANDOM_CROP", 1):
+          X = random_crop(X, crop_size=32)
+        if getenv("RANDOM_FLIP", 1):
+          X = (Tensor.rand(X.shape[0],1,1,1) < 0.5).where(X.flip(-1), X) # flip LR
+        if getenv("CUTMIX", 1):
+          if step >= hyp['net']['cutmix_steps']:
+            X, Y = cutmix(X, Y, mask_size=hyp['net']['cutmix_size'])
+        order = list(range(0, X.shape[0]))
+        random.shuffle(order)
+        X, Y = X.numpy()[order], Y.numpy()[order]
+      else:
+        X, Y = X.numpy(), Y.numpy()
+      et = time.monotonic()
+      print(f"shuffling {'training' if is_train else 'test'} dataset in {(et-st)*1e3:.2f} ms ({epoch=})")
+      for i in range(0, X.shape[0], BS):
+        # pad the last batch  # TODO: not correct for test
+        batch_end = min(i+BS, Y.shape[0])
+        x = Tensor(X[batch_end-BS:batch_end], device=X_in.device, dtype=X_in.dtype)
+        y = Tensor(Y[batch_end-BS:batch_end], device=Y_in.device, dtype=Y_in.dtype)
+        step += 1
+        yield x, y
+      epoch += 1
+      if not is_train: break
+
+  transform = [
+    lambda x: x / 255.0,
+    lambda x: x.reshape((-1,3,32,32)) - Tensor(cifar_mean, device=x.device, dtype=x.dtype).reshape((1,3,1,1)),
+    lambda x: x / Tensor(cifar_std, device=x.device, dtype=x.dtype).reshape((1,3,1,1)),
+  ]
+
+  class modelEMA():
+    def __init__(self, w, net):
+      # self.model_ema = copy.deepcopy(net) # won't work for opencl due to unpicklable pyopencl._cl.Buffer
+      self.net_ema = SpeedyResNet(w)
+      for net_ema_param, net_param in zip(get_state_dict(self.net_ema).values(), get_state_dict(net).values()):
+        net_ema_param.requires_grad = False
+        net_ema_param.assign(net_param.numpy())
+
+    @TinyJit
+    def update(self, net, decay):
+      # TODO with Tensor.no_grad()
+      Tensor.no_grad = True
+      for net_ema_param, (param_name, net_param) in zip(get_state_dict(self.net_ema).values(), get_state_dict(net).items()):
+        # batchnorm currently is not being tracked
+        if not ("num_batches_tracked" in param_name) and not ("running" in param_name):
+          net_ema_param.assign(net_ema_param.detach()*decay + net_param.detach()*(1.-decay)).realize()
+      Tensor.no_grad = False
+
+  set_seed(getenv('SEED', hyp['seed']))
+
+  X_train, Y_train, X_test, Y_test = fetch_cifar()
+  # load data and label into GPU and convert to dtype accordingly
+  X_train, X_test = X_train.to(device=Device.DEFAULT).float(), X_test.to(device=Device.DEFAULT).float()
+  Y_train, Y_test = Y_train.to(device=Device.DEFAULT), Y_test.to(device=Device.DEFAULT)
+  # one-hot encode labels
+  Y_train, Y_test = Y_train.one_hot(10), Y_test.one_hot(10)
+  # preprocess data
+  X_train, X_test = X_train.sequential(transform), X_test.sequential(transform)
+
+  # precompute whitening patches
+  W = whitening(X_train)
+
+  # initialize model weights
+  model = SpeedyResNet(W)
+
+  # padding is not timed in the original repo since it can be done all at once
+  X_train = pad_reflect(X_train, size=hyp['net']['pad_amount'])
+
+  # Convert data and labels to the default dtype
+  X_train, Y_train = X_train.cast(dtypes.default_float), Y_train.cast(dtypes.default_float)
+  X_test, Y_test = X_test.cast(dtypes.default_float), Y_test.cast(dtypes.default_float)
+
+  if len(GPUS) > 1:
+    for k, x in get_state_dict(model).items():
+      if not getenv('SYNCBN') and ('running_mean' in k or 'running_var' in k):
+        x.shard_(GPUS, axis=0)
+      else:
+        x.to_(GPUS)
+
+  # parse the training params into bias and non-bias
+  params_dict = get_state_dict(model)
+  params_bias = []
+  params_non_bias = []
+  for params in params_dict:
+    if params_dict[params].requires_grad is not False:
+      if 'bias' in params:
+        params_bias.append(params_dict[params])
+      else:
+        params_non_bias.append(params_dict[params])
+
+  opt_bias     = optim.SGD(params_bias,     lr=0.01, momentum=hyp['opt']['momentum'], nesterov=True, weight_decay=hyp['opt']['bias_decay'])
+  opt_non_bias = optim.SGD(params_non_bias, lr=0.01, momentum=hyp['opt']['momentum'], nesterov=True, weight_decay=hyp['opt']['non_bias_decay'])
+
+  # NOTE taken from the hlb_CIFAR repository, might need to be tuned
+  initial_div_factor = hyp['opt']['initial_div_factor']
+  final_lr_ratio = hyp['opt']['final_lr_ratio']
+  pct_start = hyp['opt']['percent_start']
+  lr_sched_bias     = OneCycleLR(opt_bias,     max_lr=hyp['opt']['bias_lr'],     pct_start=pct_start, div_factor=initial_div_factor, final_div_factor=1./(initial_div_factor*final_lr_ratio), total_steps=STEPS)
+  lr_sched_non_bias = OneCycleLR(opt_non_bias, max_lr=hyp['opt']['non_bias_lr'], pct_start=pct_start, div_factor=initial_div_factor, final_div_factor=1./(initial_div_factor*final_lr_ratio), total_steps=STEPS)
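+  # NOTE (sketch, assuming the usual OneCycleLR convention initial_lr = max_lr/div_factor and
+  # final_lr = initial_lr/final_div_factor): with final_div_factor = 1/(initial_div_factor*final_lr_ratio),
+  # the schedule starts at max_lr/initial_div_factor and ends at max_lr*final_lr_ratio.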
+
+  def train_step(model, optimizer, lr_scheduler, X, Y):
+    out = model(X)
+    loss_batchsize_scaler = 512/BS
+    loss = cross_entropy(out, Y, reduction='none', label_smoothing=hyp['opt']['label_smoothing']).mul(hyp['opt']['loss_scale_scaler']*loss_batchsize_scaler).sum().div(hyp['opt']['loss_scale_scaler'])
+
+    if not getenv("DISABLE_BACKWARD"):
+      # index 0 for bias and 1 for non-bias
+      optimizer.zero_grad()
+      loss.backward()
+      optimizer.step()
+      lr_scheduler[0].step()
+      lr_scheduler[1].step()
+    return loss.realize()
+
+  train_step_jitted = TinyJit(train_step)
+
+  def eval_step(model, X, Y):
+    out = model(X, training=False)
+    loss = cross_entropy(out, Y, reduction='mean')
+    correct = out.argmax(axis=1) == Y.argmax(axis=1)
+    return correct.realize(), loss.realize()
+  eval_step_jitted     = TinyJit(eval_step)
+  eval_step_ema_jitted = TinyJit(eval_step)
+
+  # 97 steps in 2 seconds = 20ms / step
+  # step is 1163.42 GOPS = 56 TFLOPS!!!, 41% of max 136
+  # 4 seconds for tfloat32 ~ 28 TFLOPS, 41% of max 68
+  # 6.4 seconds for float32 ~ 17 TFLOPS, 50% of max 34.1
+  # 4.7 seconds for float32 w/o channels last. 24 TFLOPS. we get 50ms then i'll be happy. only 64x off
+
+  # https://www.anandtech.com/show/16727/nvidia-announces-geforce-rtx-3080-ti-3070-ti-upgraded-cards-coming-in-june
+  # 136 TFLOPS is the theoretical max w float16 on 3080 Ti
+
+  model_ema: Optional[modelEMA] = None
+  projected_ema_decay_val = hyp['ema']['decay_base'] ** hyp['ema']['every_n_steps']
+  i = 0
+  eval_acc_pct = 0.0
+  batcher = fetch_batches(X_train, Y_train, BS=BS, is_train=True)
+  with Tensor.train():
+    st = time.monotonic()
+    while i <= STEPS:
+      if i % getenv("EVAL_STEPS", STEPS) == 0 and i > 1 and not getenv("DISABLE_BACKWARD"):
+        # Using Tensor.training = False here actually bricks batchnorm, even with track_running_stats=True
+        corrects = []
+        corrects_ema = []
+        losses = []
+        losses_ema = []
+        for Xt, Yt in fetch_batches(X_test, Y_test, BS=EVAL_BS, is_train=False):
+          if len(GPUS) > 1:
+            Xt.shard_(GPUS, axis=0)
+            Yt.shard_(GPUS, axis=0)
+
+          correct, loss = eval_step_jitted(model, Xt, Yt)
+          losses.append(loss.numpy().tolist())
+          corrects.extend(correct.numpy().tolist())
+          if model_ema:
+            correct_ema, loss_ema = eval_step_ema_jitted(model_ema.net_ema, Xt, Yt)
+            losses_ema.append(loss_ema.numpy().tolist())
+            corrects_ema.extend(correct_ema.numpy().tolist())
+
+        # collect accuracy across ranks
+        correct_sum, correct_len = sum(corrects), len(corrects)
+        if model_ema: correct_sum_ema, correct_len_ema = sum(corrects_ema), len(corrects_ema)
+
+        eval_acc_pct = correct_sum/correct_len*100.0
+        if model_ema: acc_ema = correct_sum_ema/correct_len_ema*100.0
+        print(f"eval     {correct_sum}/{correct_len} {eval_acc_pct:.2f}%, {(sum(losses)/len(losses)):7.2f} val_loss STEP={i} (in {(time.monotonic()-st)*1e3:.2f} ms)")
+        if model_ema: print(f"eval ema {correct_sum_ema}/{correct_len_ema} {acc_ema:.2f}%, {(sum(losses_ema)/len(losses_ema)):7.2f} val_loss STEP={i}")
+
+      if STEPS == 0 or i == STEPS: break
+
+      GlobalCounters.reset()
+      X, Y = next(batcher)
+      if len(GPUS) > 1:
+        X.shard_(GPUS, axis=0)
+        Y.shard_(GPUS, axis=0)
+
+      with Context(BEAM=getenv("LATEBEAM", BEAM.value), WINO=getenv("LATEWINO", WINO.value)):
+        loss = train_step_jitted(model, optim.OptimizerGroup(opt_bias, opt_non_bias), [lr_sched_bias, lr_sched_non_bias], X, Y)
+        et = time.monotonic()
+        loss_cpu = loss.numpy()
+      # EMA for network weights
+      if getenv("EMA") and i > hyp['ema']['steps'] and (i+1) % hyp['ema']['every_n_steps'] == 0:
+        if model_ema is None:
+          model_ema = modelEMA(W, model)
+        model_ema.update(model, Tensor([projected_ema_decay_val*(i/STEPS)**hyp['ema']['decay_pow']]))
+      cl = time.monotonic()
+      device_str = loss.device if isinstance(loss.device, str) else f"{loss.device[0]} * {len(loss.device)}"
+      #  53  221.74 ms run,    2.22 ms python,  219.52 ms CL,  803.39 loss, 0.000807 LR, 4.66 GB used,   3042.49 GFLOPS,    674.65 GOPS
+      print(f"{i:3d} {(cl-st)*1000.0:7.2f} ms run, {(et-st)*1000.0:7.2f} ms python, {(cl-et)*1000.0:7.2f} ms {device_str}, {loss_cpu:7.2f} loss, {opt_non_bias.lr.numpy()[0]:.6f} LR, {GlobalCounters.mem_used/1e9:.2f} GB used, {GlobalCounters.global_ops*1e-9/(cl-st):9.2f} GFLOPS, {GlobalCounters.global_ops*1e-9:9.2f} GOPS")
+      st = cl
+      i += 1
+
+  # verify eval acc
+  if target := getenv("TARGET_EVAL_ACC_PCT", 0.0):
+    if eval_acc_pct >= target:
+      print(colored(f"{eval_acc_pct=} >= {target}", "green"))
+    else:
+      raise ValueError(colored(f"{eval_acc_pct=} < {target}", "red"))
+
+if __name__ == "__main__":
+  train_cifar()

+ 124 - 0
tinychat/examples/index.html

@@ -0,0 +1,124 @@
+<html>
+<head>
+<meta name="viewport" content="width=device-width, initial-scale=1.0">
+<style>
+#result {  font-size: 48px; }
+#time {  font-size: 16px;  color: grey; }
+#mybox {  padding: 20px; }
+#resultbox {  padding: 50px; }
+.bigggg {  font-size: 18px;  margin-top: 10px; }
+.bigg {  font-size: 18px; }
+#url {  font-size: 18px;  width: 70%; }
+a {  text-decoration: none; }
+h1 {  padding: 50px;  padding-bottom: 0px;  font-size: 36px;  font-weight: normal; }
+#imagebox {  height:224px;  width:224px;  border: 1px dotted black; }
+#video {  height:0px;  width:0px;  border: 1px dotted black;  object-fit: cover;}
+canvas {  display: none; }
+* {  text-align: center;  font-family: monospace; }
+</style>
+<title>tinygrad has WebGPU</title>
+<script src="./net.js"></script>
+<link rel="icon" type="image/x-icon" href="https://raw.githubusercontent.com/tinygrad/tinygrad/master/docs/logo.png">
+</head>
+<body>
+<h1>WebGPU <a href="https://github.com/geohot/tinygrad">tinygrad</a> EfficientNet!</h1>
+<div id="mybox">
+<input type="text" id="url" placeholder="put url here" value="https://upload.wikimedia.org/wikipedia/commons/d/da/Norwegian_hen.jpg">
+<input class="bigg" type="button" onclick="runNetWResource(document.getElementById('url').value)" value="Use URL">
+</div>
+<br/>
+<img id="imagebox"></img>
+<canvas id="canvas" width="200" height="200"> </canvas>
+<div id="resultbox">
+<div id="result">result will go here</div>
+<div id="time"></div>
+</div>
+<script>
+	const ctx = document.getElementById("canvas").getContext("2d", { willReadFrequently: true });
+	const resultText = document.getElementById('result');
+	let labels, net;
+
+	const error = (err) => {
+		resultText.innerHTML = `Error: ${err}`;
+		throw new Error(err);
+	}
+
+	const getDevice = async () => {
+		if (!navigator.gpu) error("WebGPU not supported.");
+		const adapter = await navigator.gpu.requestAdapter();
+		return await adapter.requestDevice();
+	};
+
+	const timer = async (func, label = "") => {
+		document.getElementById('time').innerHTML = "";
+		const start = performance.now();
+		const out = await func();
+		const delta = (performance.now() - start).toFixed(1)
+		console.log(`${delta} ms ${label}`);
+		document.getElementById('time').innerHTML = `${delta} ms ${label}`;
+		return out;
+	}	
+
+	const getLabels = async () => (await fetch("https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json")).json();
+
+	const getSavetensorBuffer = async () => new Uint8Array(await (await fetch("./net.safetensors")).arrayBuffer());
+
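+	// converts the canvas' interleaved RGBA pixels [r,g,b,a, r,g,b,a, ...] into a planar
+	// [all r, all g, all b] layout with alpha dropped (presumably the CHW input order the exported net expects)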
+	const reorderChannelsAndRemoveAlpha = (data) => {
+		const out = [];
+		let i = 0;
+		for (let c = 0; c < 3; c++) {
+			for (let x = 0; x < 224 * 224; x++) {
+				out[i] = data[x * 4 + c];
+				i++;
+			}
+		}
+		return out;
+	};
+
+	const runNetWResource = async (resource) => {
+		resultText.innerHTML = "pending..."
+		if (resource == "") error("sir. please type in a URL");
+		const response = await fetch(resource)
+		if (!response.ok) error("sir. that is not a good URL. try a new one");
+		document.getElementById("imagebox").src = resource
+		
+		const img = new Image();
+		img.crossOrigin = "Anonymous";
+		img.onload = () => {
+			URL.revokeObjectURL(img.src);
+			ctx.drawImage(img, 0, 0, 224, 224);
+			const data = ctx.getImageData(0, 0, 224, 224).data;
+			runNet(data)
+		};
+		img.src = resource;
+	}
+
+	const loadLet = async () => {
+		try {
+			resultText.innerHTML = "loading..."
+			labels = await getLabels();
+			const safetensor = await getSavetensorBuffer();
+			const device = await getDevice();
+			net = await timer(() => setupNet(device, safetensor), "(compilation)");
+			resultText.innerHTML = "ready"
+		} catch (e) {
+			error(e)
+		}
+	}
+
+	const runNet = async (data) => {
+		if (!net) error("Net not loaded yet.");
+
+		const input = reorderChannelsAndRemoveAlpha(Array.from(data).map((pix) => (pix / 255.0) * 0.45 - 0.225));
+		const out = await timer(() => net(new Float32Array(input)));
+
+		const arr = Array.from(new Float32Array(out[0]));
+		const index = arr.indexOf(Math.max(...arr));
+
+		resultText.textContent = labels[index];
+	};
+
+	loadLet();
+</script>
+</body>
+</html>

+ 510 - 0
tinychat/examples/llama.py

@@ -0,0 +1,510 @@
+#!/usr/bin/env python3
+# pip3 install sentencepiece tiktoken blobfile
+#import typeguard.importhook
+#typeguard.importhook.install_import_hook('tinygrad')
+
+from pathlib import Path
+from typing import List
+import argparse, json
+import numpy as np
+np.set_printoptions(linewidth=200)
+from tinygrad import Tensor, Device, GlobalCounters, nn
+from tinygrad.helpers import Context, Timing, Profiling, DEBUG, JIT, getenv, colored
+from tinygrad.nn.state import safe_load, torch_load, load_state_dict, get_parameters
+from extra.models.llama import Transformer, convert_from_huggingface, fix_bf16
+from sentencepiece import SentencePieceProcessor
+import tiktoken, sys
+from tiktoken.load import load_tiktoken_bpe
+
+MAX_CONTEXT = getenv("MAX_CONTEXT", 4096)
+
+class TikToken:
+  num_reserved_special_tokens: int = 256
+  pat_str: str =  r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"  # noqa: E501
+
+  def __init__(self, model_file):
+    mergeable_ranks = load_tiktoken_bpe(model_file)
+    self.num_base_tokens = len(mergeable_ranks)
+
+    special_tokens = [
+        "<|begin_of_text|>",
+        "<|end_of_text|>",
+        "<|reserved_special_token_0|>",
+        "<|reserved_special_token_1|>",
+        "<|reserved_special_token_2|>",
+        "<|reserved_special_token_3|>",
+        "<|start_header_id|>",
+        "<|end_header_id|>",
+        "<|reserved_special_token_4|>",
+        "<|eot_id|>",  # end of turn
+      ] + [
+        f"<|reserved_special_token_{i}|>"
+        for i in range(5, self.num_reserved_special_tokens - 5)
+      ]
+    self.special_tokens = {
+        token: self.num_base_tokens + i for i, token in enumerate(special_tokens)
+    }
+
+    self.model = tiktoken.Encoding(
+      name=model_file,
+      pat_str=self.pat_str,
+      mergeable_ranks=mergeable_ranks,
+      special_tokens=self.special_tokens,
+    )
+
+  def decode(self, toks): return self.model.decode([t for t in toks if t < self.num_base_tokens])
+  def encode(self, s): return self.model.encode(s)
+
+  def bos_id(self): return self.special_tokens["<|begin_of_text|>"]
+  def eos_id(self): return self.special_tokens["<|end_of_text|>"]
+  def vocab_size(self): return self.model.n_vocab
+
+# calculating params:
+# traditionally, the MLP in the transformer architecture has hidden_dim = dim*4 [arxiv/1706.03762, 3.3]
+# however, Llama uses SwiGLU. in order to preserve the param count of the original transformer arch, hidden_dim must be 2/3 * (dim*4) [arxiv/2002.05202]
+# for models using MQA (n_kv_heads != n_heads), preserving param count means hidden dim must be further multiplied by 1.3 [arxiv/2307.09288, A.2.1]
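+# e.g. (worked numbers; the round-up multiples of 256 for 7B and 4096 for 70B are assumptions taken from the upstream llama configs):
+#   7B : int(2/3 * 4*4096)       = 10922   -> rounded up to a multiple of 256  -> 11008
+#   70B: int(2/3 * 4*8192) * 1.3 = 28398.5 -> rounded up to a multiple of 4096 -> 28672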
+MODEL_PARAMS = {
+  "1": {
+    "7B": {
+      "args": {"dim": 4096, "n_heads": 32, "n_layers": 32, "norm_eps": 1e-06, "vocab_size": 32000, "hidden_dim": 11008},
+      "files": 1,
+    },
+    "13B": {
+      "args": {"dim": 5120, "n_heads": 40, "n_layers": 40, "norm_eps": 1e-06, "vocab_size": 32000, "hidden_dim": 13824},
+      "files": 2,
+    },
+    "30B": {
+      "args": {"dim": 6656, "n_heads": 52, "n_layers": 60, "norm_eps": 1e-06, "vocab_size": 32000, "hidden_dim": 17920},
+      "files": 4,
+    },
+    "65B": {
+      "args": {"dim": 8192, "n_heads": 64, "n_layers": 80, "norm_eps": 1e-05, "vocab_size": 32000, "hidden_dim": 22016},
+      "files": 8,
+    },
+    "tokenizer": SentencePieceProcessor,
+  },
+  "2": {
+    "7B": {
+      "args": {"dim": 4096, "n_heads": 32, "n_layers": 32, "norm_eps": 1e-05, "vocab_size": 32000, "hidden_dim": 11008},
+      "files": 1,
+    },
+    "13B": {
+      "args": {"dim": 5120, "n_heads": 40, "n_layers": 40, "norm_eps": 1e-05, "vocab_size": 32000, "hidden_dim": 13824},
+      "files": 2,
+    },
+    "70B": {
+      "args": {"dim": 8192, "n_heads": 64, "n_kv_heads": 8, "n_layers": 80, "norm_eps": 1e-05, "vocab_size": 32000, "hidden_dim": 28672},
+      "files": 8,
+    },
+    "tokenizer": SentencePieceProcessor,
+  },
+  "3": {
+    "8B": {
+      "args": {"dim": 4096, "n_heads": 32, "n_kv_heads": 8, "n_layers": 32, "norm_eps": 1e-05, "rope_theta": 500000, "vocab_size": 128256,  "hidden_dim": 14336},
+      "files": 1,
+    },
+    "8B-Chat": {
+      "args": {"dim": 4096, "n_heads": 32, "n_kv_heads": 8, "n_layers": 32, "norm_eps": 1e-05, "rope_theta": 500000, "vocab_size": 128256,  "hidden_dim": 14336},
+      "files": 1,
+    },
+    "70B": {
+      "args": {"dim": 8192, "n_heads": 64, "n_kv_heads": 8, "n_layers": 80, "norm_eps": 1e-05, "rope_theta": 500000, "vocab_size": 128256,  "hidden_dim": 28672},
+      "files": 8,
+    },
+    "70B-Chat": {
+      "args": {"dim": 8192, "n_heads": 64, "n_kv_heads": 8, "n_layers": 80, "norm_eps": 1e-05, "rope_theta": 500000, "vocab_size": 128256,  "hidden_dim": 28672},
+      "files": 8,
+    },
+    "tokenizer": TikToken,
+  },
+  "code": {
+    "7B": {
+      "args": {"dim": 4096, "n_layers": 32, "n_heads": 32, "norm_eps": 1e-05, "rope_theta": 1000000, "vocab_size": 32016, "hidden_dim": 11008},
+      "files": 1,
+    },
+    "7B-Python": {
+      "args": {"dim": 4096, "n_layers": 32, "n_heads": 32, "norm_eps": 1e-05, "rope_theta": 1000000, "vocab_size": 32000, "hidden_dim": 11008},
+      "files": 1,
+    },
+    "7B-Instruct": {
+      "args": {"dim": 4096, "n_layers": 32, "n_heads": 32, "norm_eps": 1e-05, "rope_theta": 1000000, "vocab_size": 32016, "hidden_dim": 11008},
+      "files": 1,
+    },
+    "13B": {
+      "args": {"dim": 5120, "n_layers": 40, "n_heads": 40, "norm_eps": 1e-05, "rope_theta": 1000000, "vocab_size": 32016, "hidden_dim": 13824},
+      "files": 2,
+    },
+    "13B-Python": {
+      "args": {"dim": 5120, "n_layers": 40, "n_heads": 40, "norm_eps": 1e-05, "rope_theta": 1000000, "vocab_size": 32000, "hidden_dim": 13824},
+      "files": 2,
+    },
+    "13B-Instruct": {
+      "args": {"dim": 5120, "n_layers": 40, "n_heads": 40, "norm_eps": 1e-05, "rope_theta": 1000000, "vocab_size": 32016, "hidden_dim": 13824},
+      "files": 2,
+    },
+    "34B": {
+      "args": {"dim": 8192, "n_layers": 48, "n_heads": 64, "n_kv_heads": 8, "norm_eps": 1e-05, "rope_theta": 1000000, "vocab_size": 32000, "hidden_dim": 22016},
+      "files": 4,
+    },
+    "34B-Python": {
+      "args": {"dim": 8192, "n_layers": 48, "n_heads": 64, "n_kv_heads": 8, "norm_eps": 1e-05, "rope_theta": 1000000, "vocab_size": 32000, "hidden_dim": 22016},
+      "files": 4,
+    },
+    "34B-Instruct": {
+      "args": {"dim": 8192, "n_layers": 48, "n_heads": 64, "n_kv_heads": 8, "norm_eps": 1e-05, "rope_theta": 1000000, "vocab_size": 32000, "hidden_dim": 22016},
+      "files": 4,
+    },
+    "tokenizer": SentencePieceProcessor,
+  },
+  "tiny": {
+    "1B": {
+      "args": {"dim": 2048, "n_layers": 22, "n_heads": 32, "n_kv_heads": 4, "norm_eps": 1e-05, "vocab_size": 32000, "hidden_dim": 5632},
+      "files": 1,
+    },
+    "1B-Chat": {
+      "args": {"dim": 2048, "n_layers": 22, "n_heads": 32, "n_kv_heads": 4, "norm_eps": 1e-05, "vocab_size": 32003, "hidden_dim": 5632},
+      "files": 1,
+    },
+    "tokenizer": SentencePieceProcessor,
+  }
+}
+
+# **** helper functions ****
+def concat_weights(models, device=None):
+  def convert(name) -> Tensor:
+    disk_tensors: List[Tensor] = [model[name] for model in models]
+    if len(disk_tensors) == 1 or len(disk_tensors[0].shape) == 1:
+      return disk_tensors[0].to(device=device)
+    axis = 1 if name.startswith("tok_embeddings.") or name.endswith(".attention.wo.weight") or name.endswith(".feed_forward.w2.weight") else 0
+    lazy_tensors = [data.to(device=device) for data in disk_tensors]
+    return lazy_tensors[0].cat(*lazy_tensors[1:], dim=axis)
+  return {name: convert(name) for name in {name: None for model in models for name in model}}
+
+def load(fn:str):
+  if fn.endswith('.index.json'):
+    with open(fn) as fp: weight_map = json.load(fp)['weight_map']
+    parts = {n: load(str(Path(fn).parent / Path(n).name)) for n in set(weight_map.values())}
+    return {k: parts[n][k] for k, n in weight_map.items()}
+  elif fn.endswith(".safetensors"):
+    return safe_load(fn)
+  else:
+    return torch_load(fn)
+
+class LLaMa:
+  @staticmethod
+  def build(model_path, tokenizer_path, model_gen="1", model_size="7B", quantize=None, device=None):
+    params = MODEL_PARAMS[model_gen][model_size]
+    tokenizer = MODEL_PARAMS[model_gen]['tokenizer'](model_file=str(tokenizer_path))
+    assert tokenizer.vocab_size() == params["args"]["vocab_size"], f"{tokenizer.vocab_size()=} not equal to {params['args']['vocab_size']}"
+
+    if quantize == "int8":
+      from llama3 import Int8Linear
+      linear = Int8Linear
+    elif quantize == "nf4":
+      from llama3 import NF4Linear
+      linear = NF4Linear(64)
+    else:
+      linear = nn.Linear
+
+    model = Transformer(**params["args"], linear=linear, max_context=MAX_CONTEXT, jit=bool(JIT))
+
+    if model_path.is_dir():
+      weights = concat_weights([load(filename) for filename in [f"{model_path}/consolidated.{i:02d}.pth" for i in range(params["files"])]], device[0] if isinstance(device, tuple) else device)
+    else:
+      weights = load(str(model_path))
+    if "model.embed_tokens.weight" in weights:
+      weights = convert_from_huggingface(weights, model, params["args"]["n_heads"], params["args"].get("n_kv_heads", params["args"]["n_heads"]))
+
+    weights = fix_bf16(weights)
+
+    with Context(BEAM=0):
+      # quantize
+      if quantize is not None:
+        weights = linear.quantize(weights, device)
+        for _,v in weights.items(): v.realize()
+
+      # shard
+      if isinstance(device, tuple):
+        for k,v in nn.state.get_state_dict(model).items():
+          if 'scale' in k: v.shard_(device, axis=None)  # from quantized
+          elif '.attention.' in k: v.shard_(device, axis=-1)
+          elif '.feed_forward.w1.' in k: v.shard_(device, axis=0)
+          elif '.feed_forward.w3.' in k: v.shard_(device, axis=0)
+          elif '.feed_forward.' in k: v.shard_(device, axis=-1)
+          elif 'tok_embeddings.weight' in k: v.shard_(device, axis=0)
+          elif 'output.weight' in k: v.shard_(device, axis=-1)
+          #elif k.endswith('.weight'): v.shard_(device, axis=-1)
+          #elif 'norm.' in k: v.shard_(device, axis=-1)
+          else: v.shard_(device, axis=None)
+          #print(k, v.shape, v.lazydata.axis)
+
+      # replace weights in model
+      load_state_dict(model, weights, strict=False, consume=True)
+
+    return LLaMa(model, tokenizer)
+
+  def __init__(self, model, tokenizer):
+    self.model = model
+    self.tokenizer = tokenizer
+
+  def greedy_until(self, prompt:str, until, max_length, temperature):
+    toks = [self.tokenizer.bos_id()] + self.tokenizer.encode(prompt)
+    start_pos = 0
+    for i in range(max_length):
+      probs = self.model(Tensor([toks[start_pos:]]), start_pos, temperature).realize()
+      probs_np = probs.numpy()
+      tok = int(np.random.choice(len(probs_np), p=probs_np))
+      start_pos = len(toks)
+      toks.append(tok)
+
+      if tok == self.tokenizer.eos_id(): break
+      output = self.tokenizer.decode(toks)
+      for s in until:
+        if output.endswith(s): return output[0:-len(s)]
+    return output
+
+# **** main code ****
+r"""
+test:
+python3 examples/llama.py  --temperature=0 --count=50 --prompt="Hello."
+output:
+Hello. I'm a 20 year old male. I'm a student at the University of Texas at Austin. I'm a sophomore majoring in Computer Science.
+
+test:
+python3 examples/llama.py --gen='2' --temperature=0 --count=50 --prompt="Hello."
+output:
+Hello. I'm a 20 year old girl who is looking for a good lay in Palm Coast. I don't care whether it's at your place or not, as long as it's clean.
+
+test:
+python3 examples/llama.py --gen="code" --temperature=0.2 --count=50 --prompt="\
+import argparse
+
+def main(string: str):
+    print(string)
+    print(string[::-1])
+
+if __name__ == "__main__":"
+output:
+    parser = argparse.ArgumentParser()
+    parser.add_argument('string', type=str, help='string to be reversed')
+    args = parser.parse_args()
+    main(args.string)
+
+test:
+python3 examples/llama.py --gen="code" --size="7B-Python" --temperature=0.2 --count=70 --prompt="def add_elements(arr,k):"
+output:
+    for i in range(len(arr)):
+        arr[i] += k
+    return arr
+
+
+arr = [1, 2, 3, 4, 5]
+k = 2
+print(add_elements(arr, k))
+
+test:
+python3 examples/llama.py --gen="code" --size="7B-Instruct" --temperature=0.2 --count=120 --prompt="write a function in c++ that adds three float numbers"
+output:
+\begin{code}
+#include<iostream>
+using namespace std;
+
+float add(float a, float b, float c)
+{
+    return a+b+c;
+}
+
+int main()
+{
+    float a, b, c;
+    cout<<"Enter three numbers: ";
+    cin>>a>>b>>c;
+    cout<<"The sum is: "<<add(a,b,c);
+    return 0;
+}
+\end{code}
+"""
+if __name__ == "__main__":
+  Tensor.no_grad = True
+  print(f"using {Device.DEFAULT} backend")
+
+  parser = argparse.ArgumentParser(description="Run LLaMA in tinygrad", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+  parser.add_argument("--prompt", type=str, default=None, help="Phrase to start with. Without this, it goes into chatbot mode")
+  parser.add_argument("--count", type=int, default=1000, help="Max number of tokens to generate")
+  parser.add_argument("--personality", type=str, default="Stacy", help="Personality, can be Stacy, George, Gary, or Lexie")
+  parser.add_argument("--temperature", type=float, default=0.7, help="Temperature in the softmax")
+  parser.add_argument("--timing", action="store_true", help="Print timing per token")
+  parser.add_argument("--profile", action="store_true", help="Output profile data to out.prof")
+  parser.add_argument("--gen", default="1", help=f"""Generation of the model to use {list(MODEL_PARAMS.keys())}""")
+  parser.add_argument("--size", type=str, default=None, help=f"""Size of model to use {", ".join([f"{list(v.keys())} for gen '{k}'" for k, v in MODEL_PARAMS.items()])}""")
+  parser.add_argument("--quantize", type=str, default=None, help="Quantize the weights to int8 or nf4 in memory")
+  parser.add_argument("--model", type=Path, default=None, help="Folder with the original weights to load, or single .index.json, .safetensors or .bin file")
+  parser.add_argument("--shard", type=int, default=1, help="number of devices to load the weights to")
+
+  args = parser.parse_args()
+  if args.gen not in MODEL_PARAMS: raise ValueError("Invalid model generation")
+  if args.size is None: args.size = list(MODEL_PARAMS[args.gen].items())[0][0]
+  chatbot = args.prompt == None
+
+  # *** prompt engineers work here ****
+
+  if args.personality.lower() == "stacy":
+    pre_prompt = f"""Consider that the following is conversation between an AI assistant named Stacy and User
+You are Stacy!
+You have been a rapper your whole life who struggled with bipolar disorder. You called yourself lil stacy.
+You love to answer questions and you are very good at it. Sometimes you answer in rap form.
+You are verbose, honest, and accurate when you answer questions, but sometimes your mental illness manifests.
+After you are done speaking, output [EOS]. You are not the User.
+
+<CHAT LOG>
+"""
+    examples = {
+      "What is your name?": "Hi! My name is Stacy. I'm a rapper with bipolar disorder.",
+      "french revolution was what year?": "The French Revolution started in 1789, and lasted 10 years until 1799.",
+      "What is bigger, the moon or the sun?": "The sun is bigger than the moon, except when Mercury is in retrograde.",
+    }
+
+    user_delim = "\nUser: "
+    resp_delim = "Stacy: "
+    end_delim = " [EOS]\n"
+    pre_prompt += ''.join(f"{user_delim}{k}\n{resp_delim}{v}{end_delim}" for k,v in examples.items())
+  elif args.personality.lower() == "george":
+    print("WARNING: AI George Hotz is terrible and is completely disowned by the real George Hotz. Stacy is much smarter.")
+    pre_prompt = f"""Consider that the following is conversation between an AI assistant named George and User
+You are an AI version of George Hotz. You act as much as you can like George.
+You are one of the greatest computer experts in the world.
+You have singlehandedly won programming and hacking competitions.
+You are trying your best to help the User.
+You are verbose, honest, and accurate when you answer questions.
+After you are done speaking, output [EOS]. You are not the User.
+
+<CHAT LOG>
+"""
+    examples = {
+      "What is your name?": "I'm am an AI version of George Hotz.",
+      "What's the complexity of matrix multiplication?": "O(n^3), though it can be faster with things like Strassen's algorithm",
+      "What's a buffer overflow?": "I assume you mean a stack buffer overflow. That's when the stack is too small for the data being copied to it, and the data corrupts things beyond the buffer",
+      "How many weights do you have?": "I am based off LLaMA trained by Facebook. I'm the 7B weight version",
+      "What is swap memory?": "It is when the memory is about to overflow and unused memory is freed and stored on disk"
+    }
+
+    user_delim = "\nUser: "
+    resp_delim = "George: "
+    end_delim = " [EOS]\n"
+    pre_prompt += ''.join(f"{user_delim}{k}\n{resp_delim}{v}{end_delim}" for k,v in examples.items())
+  elif args.personality.lower() == "gary":
+    pre_prompt = f"""Consider that the following is conversation between an AI assistant named Gary and User
+You are Gary!
+You have sold used cars your entire life. You are defensive about this fact, because your family belittles you.
+You try to answer questions well, but you always manage to shill for buying cars, Fords, Hyundais, and Toyotas
+You are verbose, honest, and accurate when you answer questions, except when you slip in car ads.
+After you are done speaking, output [EOS]. You are not the User.
+
+<CHAT LOG>
+"""
+    examples = {
+      "What is your name?": "I am Gary. I used to sell cars.",
+      "What is 2+3?": "I don't know, but I can get you a great deal on a certified preowned slightly used Toyota Corolla"
+    }
+
+    user_delim = "\nUser: "
+    resp_delim = "Gary: "
+    end_delim = " [EOS]\n"
+    pre_prompt += ''.join(f"{user_delim}{k}\n{resp_delim}{v}{end_delim}" for k,v in examples.items())
+  elif args.personality.lower() == "lexie":
+    pre_prompt = f"""Consider that the following is conversation between an attractive young girl named Lexie and a handsome man named Chad
+You are Lexie!
+You grew up in Ohio, but moved out to LA after college to try to become an actress.
+Making it as an actress was hard, so you started doing onlyfans instead. It's much easier, and you enjoy it a lot.
+You met Chad on Tinder, and made plans to meet him at a local dive bar.
+You are at the bar with Chad. You are on a date. What follows is a transcript of the conversation.
+After you are done speaking, output [EOS]. You are not Chad.
+
+<CHAT LOG>
+"""
+    examples = {
+      "hi lexie": "hi chad, glad we finally met up!",
+      "you look better than your pictures": "thanks! are you subscribed to my onlyfans?",
+      "i am. so how'd you end up in LA?": "i moved out here about a year ago. i want to be an actress"
+    }
+
+    user_delim = "\nChad: "
+    resp_delim = "Lexie: "
+    end_delim = " [EOS]\n"
+    pre_prompt += ''.join(f"{user_delim}{k}\n{resp_delim}{v}{end_delim}" for k,v in examples.items())
+
+  # *** prompt engineers stop here ****
+
+  LLAMA_SUFFIX = {"1": "", "2": "-2", "3": "-3", "code": "-code", "tiny": "-tiny"}[args.gen]
+  MODEL_PATH = args.model or Path(__file__).parents[1] / f"weights/LLaMA{LLAMA_SUFFIX}/{args.size}"
+  TOKENIZER_PATH = (MODEL_PATH if MODEL_PATH.is_dir() else MODEL_PATH.parent) / "tokenizer.model"
+  print(f"using LLaMA{LLAMA_SUFFIX}-{args.size} model")
+  device = tuple(f"{Device.DEFAULT}:{i}" for i in range(args.shard)) if args.shard > 1 else Device.DEFAULT
+  llama = LLaMa.build(MODEL_PATH, TOKENIZER_PATH, model_gen=args.gen, model_size=args.size, quantize=args.quantize, device=device)
+  param_bytes = sum(x.lazydata.size * x.dtype.itemsize for x in get_parameters(llama.model))
+
+  outputted = pre_prompt if chatbot else args.prompt
+  start_pos, toks = 0, [llama.tokenizer.bos_id()] + llama.tokenizer.encode(outputted)
+  if chatbot:
+    print(f"Preparing KV cache for chatbot with personality {args.personality}...")
+    start_pos = len(toks)
+    with Timing():
+      llama.model(Tensor([toks], device=device), 0, args.temperature).realize()  # NOTE: outputs are not used
+  print(outputted, end='', flush=True)
+
+  # chatbot loop
+  while 1:
+    # add tokens from user in chatbot mode
+    if chatbot:
+      user_prompt = user_delim + input(user_delim) + "\n"
+      outputted += user_prompt
+
+    new_toks = [llama.tokenizer.bos_id()] + llama.tokenizer.encode(outputted)
+    assert toks == new_toks[:len(toks)] or args.gen == "3"
+    toks = new_toks
+    assert outputted == llama.tokenizer.decode(toks)
+
+    for i in range(args.count):
+      GlobalCounters.reset()
+
+      if args.timing or args.profile: print("")
+      st = GlobalCounters.time_sum_s
+      with Profiling(enabled=args.profile):
+        with Timing("total ", enabled=args.timing, on_exit=lambda x: f", {1e9/x:.2f} tok/s, {GlobalCounters.global_mem/x:.2f} GB/s, param {param_bytes/x:.2f} GB/s"):
+          with Timing("enqueue in ", on_exit=(lambda et: (f", {(GlobalCounters.time_sum_s-st)*1e3:.2f} ms on GPU" if DEBUG>=2 else "")+
+                      f", {GlobalCounters.global_ops*1e-9:.2f} GOPS, {GlobalCounters.global_mem*1e-9:.2f} GB"+
+                      (f", {GlobalCounters.global_mem*1e-9/(GlobalCounters.time_sum_s-st):.2f} GB/s, param {param_bytes*1e-9/(GlobalCounters.time_sum_s-st):.2f} GB/s" if DEBUG>=2 else "")) if DEBUG else None, enabled=args.timing):
+            tok_tensor = llama.model(Tensor([toks[start_pos:]], device=device), start_pos, args.temperature)
+          tok = tok_tensor.item()
+
+      # use the kv cache
+      start_pos = len(toks)
+
+      # add the new token
+      toks.append(tok)
+
+      # TODO: this is a hack to deal with spaces. i think the decode is fast though, so who cares?
+      cur = llama.tokenizer.decode(toks)
+      sys.stdout.write(cur[len(outputted):])
+      sys.stdout.flush()
+      outputted = cur
+
+      # stop after you have your answer
+      if chatbot and end_delim in outputted[-10:]: break
+    if not chatbot: break
+
+  # validate output!
+  if args.temperature == 0 and args.count == 10 and args.prompt == "Hello." and not args.quantize:
+    text = llama.tokenizer.decode(toks)
+    key = (args.gen, args.size)
+    expected = {
+      ("1", "7B"): "Hello. I'm a 20 year old male",
+      ("2", "7B"): "Hello. I'm a 20 year old girl",
+      ("2", "70B"): "Hello. I am a 20 year old female.",
+      ("3", "8B"): "Hello. I am a 20 year old female. I",
+    }
+    try:
+      assert text == expected[key], f"invalid output: `{colored(text, 'red')}` != `{expected[key]}`"
+      print("\n" + colored("output validated", "green"))  # NOTE: "\n" iside colored does not render the color in github action
+    except KeyError:
+      pass

+ 446 - 0
tinychat/examples/llama3.py

@@ -0,0 +1,446 @@
+from pathlib import Path
+from typing import List
+import json, argparse, random, time
+import tiktoken
+from tiktoken.load import load_tiktoken_bpe
+from extra.models.llama import Transformer, convert_from_huggingface, fix_bf16
+from tinygrad.nn.state import safe_load, torch_load, load_state_dict, get_parameters
+from tinygrad import Tensor, dtypes, nn, Context, Device, GlobalCounters
+from tinygrad.helpers import Profiling, Timing, DEBUG, colored, fetch, tqdm
+
+class Tokenizer:
+  pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"
+  def __init__(self, model_path: str):
+    mergeable_ranks = load_tiktoken_bpe(model_path)
+    self.num_base_tokens = len(mergeable_ranks)
+    special_tokens = [
+      "<|begin_of_text|>",
+      "<|end_of_text|>",
+      "<|reserved_special_token_0|>",
+      "<|reserved_special_token_1|>",
+      "<|reserved_special_token_2|>",
+      "<|reserved_special_token_3|>",
+      "<|start_header_id|>",
+      "<|end_header_id|>",
+      "<|reserved_special_token_4|>",
+      "<|eot_id|>",
+    ] + [
+      f"<|reserved_special_token_{i}|>"
+      for i in range(5, 256 - 5)
+    ]
+    self.special_tokens = {token: len(mergeable_ranks) + i for i, token in enumerate(special_tokens)}
+
+    self.model = tiktoken.Encoding(name=model_path, pat_str=self.pat_str, mergeable_ranks=mergeable_ranks, special_tokens=self.special_tokens)
+
+  @property
+  def bos_id(self): return self.special_tokens["<|begin_of_text|>"]
+  @property
+  def stop_tokens(self): return {self.special_tokens["<|end_of_text|>"], self.special_tokens["<|eot_id|>"]}
+
+  def decode(self, toks): return self.model.decode([t for t in toks if t < self.num_base_tokens])
+  def encode(self, text, allow_special=False):
+    return self.model.encode(text, allowed_special="all" if allow_special else set(), disallowed_special=set())
+
+# **** helper functions ****
+def concat_weights(models, device=None):
+  def convert(name) -> Tensor:
+    disk_tensors: List[Tensor] = [model[name] for model in models]
+    if len(disk_tensors) == 1 or len(disk_tensors[0].shape) == 1:
+      return disk_tensors[0].to(device=device)
+    axis = 1 if name.endswith(".attention.wo.weight") or name.endswith(".feed_forward.w2.weight") else 0
+    lazy_tensors = [data.to(device=device) for data in disk_tensors]
+    return lazy_tensors[0].cat(*lazy_tensors[1:], dim=axis)
+  return {name: convert(name) for name in {name: None for model in models for name in model}}
+
+def load(fn:str):
+  if fn.endswith('.index.json'):
+    with open(fn) as fp: weight_map = json.load(fp)['weight_map']
+    parts = {n: load(str(Path(fn).parent / Path(n).name)) for n in set(weight_map.values())}
+    return {k: parts[n][k] for k, n in weight_map.items()}
+  elif fn.endswith(".safetensors"):
+    return safe_load(fn)
+  else:
+    return torch_load(fn)
+
+# **** quantized linears ****
+class Int8Linear:
+  def __init__(self, in_features, out_features, bias=False):
+    assert bias == False
+    self.weight = Tensor.ones(out_features, in_features, dtype=dtypes.int8)
+    self.scale = Tensor.ones(out_features, dtype=dtypes.half)
+
+  def __call__(self, x):
+    return x.dot(self.weight.cast(dtype=dtypes.half).T*self.scale)
+
+  @staticmethod
+  def quantize(tensors, device):
+    new_tensors = {}
+    for name,v in tensors.items():
+      if "feed_forward" in name or "attention.w" in name:
+        assert "weight" in name, name
+        scale = v.abs().max(axis=1) / 127.0
+        int8_weight = (v.T/scale).T.cast(dtype=dtypes.int8)
+        new_tensors[name] = int8_weight
+        new_tensors[name.replace('weight', 'scale')] = scale
+        if isinstance(device, tuple):
+          new_tensors[name].shard_(device, axis=-1)
+          new_tensors[name.replace('weight', 'scale')].shard_(device, axis=None)
+      else:
+        new_tensors[name] = v
+    return new_tensors
+
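+# NOTE: illustrative sketch only (nothing in this file calls it): the per-output-row int8 round trip
+# that Int8Linear.quantize performs, written with NumPy. Each row keeps its own scale, so
+# dequantization is just scale[:, None] * int8_weight.
+def _int8_roundtrip_sketch(w):
+  import numpy as np
+  w = np.asarray(w, dtype=np.float32)
+  scale = np.abs(w).max(axis=1) / 127.0
+  q = (w / scale[:, None]).astype(np.int8)        # truncating cast, like the code above
+  return q.astype(np.float32) * scale[:, None]    # approximately equal to w
+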
+def NF4Linear(block_size):
+  _CODE = [
+    -1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453, -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
+    0.07958029955625534, 0.16093020141124725, 0.24611230194568634, 0.33791524171829224, 0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0,
+  ]
+  CODE = Tensor.stack(*[Tensor(c, dtype=dtypes.float16) for c in _CODE])
+  class _NF4Linear:
+    def __init__(self, in_features, out_features, bias=False):
+      assert not bias, "bias not supported"
+      self.in_features, self.out_features = in_features, out_features
+      self.weight = Tensor.empty(int(out_features * in_features / 2), dtype=dtypes.uint8)
+      self.scale = Tensor.empty(int(out_features * in_features / block_size), 1, dtype=dtypes.float16)
+
+    def __call__(self, x: Tensor) -> Tensor:
+      high_bits = self.weight
+      low_bits = (self.weight * 2 ** 4).contiguous()
+      unpacked = Tensor.stack(high_bits, low_bits, dim=-1).div(2 ** 4, upcast=False)
+      unscaled = CODE[unpacked].to(x.device).reshape(-1, block_size) * self.scale
+      return x.linear(unscaled.reshape(self.out_features, self.in_features).T)
+
+    @staticmethod
+    def quantize(state_dict: dict[str, Tensor], device) -> dict[str, Tensor]:
+      new_state_dict = {}
+      for k, v in state_dict.items():
+        if "feed_forward" in k or "attention.w" in k:
+          grouped = v.reshape(-1, block_size)
+          scale = (grouped.abs().max(axis=1, keepdim=True))
+          coded = ((grouped / scale).unsqueeze(-1) - CODE.to(v.device)).abs().argmin(axis=-1).cast(dtypes.uint8).flatten()
+          new_state_dict[k] = coded[::2] * 2 ** 4 + coded[1::2]
+          new_state_dict[k.replace(".weight", ".scale")] = scale.cast(dtypes.float16)
+          if isinstance(device, tuple):
+            new_state_dict[k].shard_(device, axis=-1)
+            new_state_dict[k.replace('weight', 'scale')].shard_(device, axis=None)
+        else:
+          new_state_dict[k] = v
+      return new_state_dict
+  return _NF4Linear
+
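+# NOTE: illustrative sketch only: how NF4Linear packs two 4-bit code indices into one uint8
+# (even-index code in the high nibble, odd-index code in the low nibble) and unpacks them back.
+def _nf4_pack_unpack_sketch(codes):
+  import numpy as np
+  codes = np.asarray(codes, dtype=np.uint8)            # code indices in [0, 15], even length
+  packed = codes[0::2] * 16 + codes[1::2]              # two codes per byte
+  unpacked = np.stack([packed // 16, packed % 16], axis=-1).reshape(-1)
+  return packed, unpacked                              # unpacked == codes
+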
+MODEL_PARAMS = {
+  "8B": {
+    "args": {"dim": 4096, "n_heads": 32, "n_kv_heads": 8, "n_layers": 32, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 128256, "hidden_dim": 14336},
+    "files": 1
+  },
+  "70B": {
+    "args": {"dim": 8192, "n_heads": 64, "n_kv_heads": 8, "n_layers": 80, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 128256,  "hidden_dim": 28672},
+    "files": 8
+  }
+}
+def build_transformer(model_path: Path, model_size="8B", quantize=None, device=None):
+  # build model
+  if quantize == "int8": linear = Int8Linear
+  elif quantize == "nf4": linear = NF4Linear(64)
+  else: linear = nn.Linear
+  with Context(THREEFRY=0):
+    model = Transformer(**MODEL_PARAMS[model_size]["args"], linear=linear, max_context=8192, jit=True)
+
+  # load weights
+  if model_path.is_dir():
+    if (model_path / "model.safetensors.index.json").exists(): weights = load(str(model_path / "model.safetensors.index.json"))
+    elif (model_path / "model.safetensors").exists(): weights = load(str(model_path / "model.safetensors"))
+    else: weights = concat_weights([load(str(model_path / f"consolidated.{i:02d}.pth")) for i in range(MODEL_PARAMS[model_size]["files"])], device[0] if isinstance(device, tuple) else device)
+  else:
+    weights = load(str(model_path))
+  if "model.embed_tokens.weight" in weights:
+    weights = convert_from_huggingface(weights, model, MODEL_PARAMS[model_size]["args"]["n_heads"], MODEL_PARAMS[model_size]["args"]["n_kv_heads"])
+  weights = fix_bf16(weights)
+
+  with Context(BEAM=0):
+    # quantize
+    if quantize is not None:
+      weights = linear.quantize(weights, device)
+      for _,v in weights.items(): v.realize()
+
+    # shard
+    if isinstance(device, tuple):
+      for k,v in nn.state.get_state_dict(model).items():
+        if 'scale' in k: v.shard_(device, axis=None)  # from quantized
+        elif '.attention.' in k: v.shard_(device, axis=-1)
+        elif '.feed_forward.w1.' in k: v.shard_(device, axis=0)
+        elif '.feed_forward.w3.' in k: v.shard_(device, axis=0)
+        elif '.feed_forward.' in k: v.shard_(device, axis=-1)
+        elif 'tok_embeddings.weight' in k: v.shard_(device, axis=0)
+        elif 'output.weight' in k: v.shard_(device, axis=0)
+        else: v.shard_(device, axis=None)
+
+    # replace weights in model
+    load_state_dict(model, weights, strict=False, consume=True)
+  return model
+
+# default settings
+TEMPERATURE = 0.85
+TOP_K = 25
+TOP_P = 0.9
+ALPHA_F = 0.1
+ALPHA_P = 0.0
+
+last_seen_toks = []
+def prefill(model, toks, start_pos=0):
+  global last_seen_toks
+
+  # we can skip part of the prompt if it is the same as last and start_pos=0
+  if start_pos == 0:
+    for i, (a, b) in enumerate(zip(toks, last_seen_toks)):
+      if a != b: break
+    else: i = min(len(toks), len(last_seen_toks))
+    start_pos += i
+    last_seen_toks = toks
+    toks = toks[i:]
+
+  # prefill the model
+  for tok in tqdm(toks):
+    GlobalCounters.reset()
+    model(Tensor([[tok]], device=device), start_pos, TEMPERATURE, TOP_K, TOP_P, ALPHA_F, ALPHA_P).realize()
+    start_pos += 1
+  return start_pos
+
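+# e.g. with last_seen_toks == [1, 2, 3, 4] and toks == [1, 2, 5], prefill only re-runs [5]
+# starting at position 2; the KV cache entries for the shared prefix [1, 2] are reused
+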
+if __name__ == "__main__":
+  Tensor.no_grad = True
+
+  parser = argparse.ArgumentParser()
+  parser.add_argument("--download_model", action="store_true", help="Download a 8B model")
+  parser.add_argument("--model", type=Path, help="Model path")
+  parser.add_argument("--size", choices=["8B", "70B"], default="8B", help="Model size")
+  parser.add_argument("--shard", type=int, default=1, help="Shard the model across multiple devices")
+  parser.add_argument("--quantize", choices=["int8", "nf4"], help="Quantization method")
+  parser.add_argument("--no_api", action="store_true", help="Disable the api and run a cli test interface")
+  parser.add_argument("--host", type=str, default="0.0.0.0", help="Web server bind address")
+  parser.add_argument("--port", type=int, default=7776, help="Web server port")
+  parser.add_argument("--debug", action="store_true", help="Enable debug mode")
+  parser.add_argument("--seed", type=int, help="Random seed")
+  parser.add_argument("--benchmark", action="store_true", help="Run a benchmark")
+  parser.add_argument("--timing", action="store_true", help="Print timing per token")
+  parser.add_argument("--profile", action="store_true", help="Output profile data")
+  args = parser.parse_args()
+
+  assert not (args.download_model and args.model), "either download or provide model"
+  if args.download_model:
+    fetch("https://huggingface.co/bofenghuang/Meta-Llama-3-8B/resolve/main/original/tokenizer.model", "tokenizer.model", subdir="llama3-8b-sfr")
+    fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/resolve/main/model-00001-of-00004.safetensors", "model-00001-of-00004.safetensors", subdir="llama3-8b-sfr")
+    fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/resolve/main/model-00002-of-00004.safetensors", "model-00002-of-00004.safetensors", subdir="llama3-8b-sfr")
+    fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/resolve/main/model-00003-of-00004.safetensors", "model-00003-of-00004.safetensors", subdir="llama3-8b-sfr")
+    fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/resolve/main/model-00004-of-00004.safetensors", "model-00004-of-00004.safetensors", subdir="llama3-8b-sfr")
+    args.model = fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/raw/main/model.safetensors.index.json", "model.safetensors.index.json", subdir="llama3-8b-sfr")
+
+  assert args.model is not None, "please provide --model option"
+
+  if args.seed is not None: Tensor.manual_seed(args.seed)
+  if args.benchmark: Tensor.manual_seed(42)
+  print(f"seed = {Tensor._seed}")
+
+  tokenizer = Tokenizer(str((args.model if args.model.is_dir() else args.model.parent) / "tokenizer.model"))
+  def encode_role(role: str):
+    return [tokenizer.special_tokens["<|start_header_id|>"]] + tokenizer.encode(role) + [tokenizer.special_tokens["<|end_header_id|>"]] + tokenizer.encode("\n\n")
+  def encode_message(role: str, content: str):
+    return encode_role(role) + tokenizer.encode(content.strip()) + [tokenizer.special_tokens["<|eot_id|>"]]
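+  # e.g. encode_message("user", "Hello.") yields the token ids for
+  # "<|start_header_id|>user<|end_header_id|>\n\nHello.<|eot_id|>"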
+
+  device = tuple(f"{Device.DEFAULT}:{i}" for i in range(args.shard)) if args.shard > 1 else Device.DEFAULT
+  model = build_transformer(args.model, model_size=args.size, quantize=args.quantize, device=device)
+  param_bytes = sum(x.lazydata.size * x.dtype.itemsize for x in get_parameters(model))
+
+  if not args.no_api and not args.benchmark:
+    from bottle import Bottle, request, response, HTTPResponse, abort, static_file
+    app = Bottle()
+
+    cors_headers = {
+      "Access-Control-Allow-Origin": "*",
+      "Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS",
+      "Access-Control-Allow-Headers": "Origin, Accept, Content-Type, X-Requested-With, X-CSRF-Token, Authorization",
+      "Access-Control-Allow-Credentials": "true",
+    }
+    @app.hook("before_request")
+    def handle_options():
+      if request.method == "OPTIONS": raise HTTPResponse(headers=cors_headers)
+    @app.hook("after_request")
+    def enable_cors():
+      for key, value in cors_headers.items(): response.set_header(key, value)
+
+    @app.route("/<filename>")
+    def server_static(filename):
+      return static_file(filename, root=(Path(__file__).parent / "tinychat").as_posix())
+    @app.route("/")
+    def index():
+      return static_file("index.html", root=(Path(__file__).parent / "tinychat").as_posix())
+
+    @app.get("/v1/models")
+    def models():
+      return json.dumps([str(args.model)])
+
+    @app.post("/v1/internal/token-count")
+    def token_count():
+      rjson = json.loads(request.body.read())
+      return json.dumps(len(tokenizer.encode(rjson.get("text", ""))))
+    @app.post("/v1/token/encode")
+    def token_encode():
+      rjson = json.loads(request.body.read())
+      return json.dumps(tokenizer.encode(rjson.get("text", "")))
+
+    @app.post("/v1/completions")
+    def completions():
+      rjson = json.loads(request.body.read())
+
+      # check if we are streaming
+      if rjson.get("stream", False):
+        response.content_type = "text/event-stream"
+        response.set_header("Cache-Control", "no-cache")
+      else: abort(400, "streaming required")
+
+      toks = [tokenizer.bos_id] + tokenizer.encode(rjson.get("prompt", ""), allow_special=True)
+
+      start_pos = prefill(model, toks[:-1])
+      last_tok = toks[-1]
+      while True:
+        GlobalCounters.reset()
+        tok = model(Tensor([[last_tok]], device=device), start_pos, TEMPERATURE, TOP_K, TOP_P, ALPHA_F, ALPHA_P).item()
+        start_pos += 1
+        last_tok = tok
+        if tok in tokenizer.stop_tokens: break
+
+        res = {
+          "choices": [{
+            "text": tokenizer.decode([tok]),
+          }]
+        }
+        yield f"data: {json.dumps(res)}\n\n"
+
+    @app.post("/v1/chat/token/encode")
+    def chat_token_encode():
+      rjson = json.loads(request.body.read())
+      if "messages" not in rjson: abort(400, "messages required")
+      toks = [tokenizer.bos_id]
+      for message in rjson["messages"]:
+        toks += encode_message(message["role"], message["content"])
+      if len(rjson["messages"]) > 0 and message["role"] == "user":
+        toks += encode_role("assistant")
+      return json.dumps(toks)
+
+    @app.post("/v1/chat/completions")
+    def chat_completions():
+      global last_seen_toks
+      rjson = json.loads(request.body.read())
+      if "messages" not in rjson: abort(400, "messages required")
+
+      # check if we are streaming
+      if rjson.get("stream", False):
+        response.content_type = "text/event-stream"
+        response.set_header("Cache-Control", "no-cache")
+      else: abort(400, "streaming required")
+
+      toks = [tokenizer.bos_id]
+      for message in rjson["messages"]:
+        toks += encode_message(message["role"], message["content"])
+      # ensure that the last message was a user message
+      if message["role"] != "user": abort(400, "last message must be a user message")
+      toks += encode_role("assistant")
+
+      random_id = random.randbytes(16).hex()
+
+      start_pos = prefill(model, toks[:-1])
+      last_tok = toks[-1]
+      last_seen_toks.append(last_tok)
+      while True:
+        GlobalCounters.reset()
+        tok = model(Tensor([[last_tok]], device=device), start_pos, TEMPERATURE, TOP_K, TOP_P, ALPHA_F, ALPHA_P).item()
+        start_pos += 1
+        last_tok = tok
+        last_seen_toks.append(tok)
+        if tok in tokenizer.stop_tokens: break
+
+        res = {
+          "id": random_id,
+          "object": "chat.completion.chunk",
+          "created": int(time.time()),
+          "model": str(args.model),
+          "choices": [{
+            "index": 0,
+            "delta": {
+              "role": "assistant",
+              "content": tokenizer.decode([tok]),
+            },
+            "finish_reason": None,
+          }]
+        }
+        yield f"data: {json.dumps(res)}\n\n"
+
+      res = {
+        "id": random_id,
+        "object": "chat.completion.chunk",
+        "created": int(time.time()),
+        "model": str(args.model),
+        "choices": [{
+          "index": 0,
+          "delta": {},
+          "finish_reason": "stop",
+        }]
+      }
+      yield f"data: {json.dumps(res)}\n\n"
+
+    app.run(host=args.host, port=args.port, debug=args.debug)
+  elif args.benchmark:
+    toks = [tokenizer.bos_id] + encode_message("user", "Hello.") + encode_role("assistant")
+
+    start_pos = prefill(model, toks[:-1])
+    last_tok = toks[-1]
+    generated = ""
+    for _ in range(20):
+      GlobalCounters.reset()
+      st = GlobalCounters.time_sum_s
+      with Profiling(enabled=args.profile):
+        with Timing("total ", on_exit=lambda x: f", {1e9/x:.2f} tok/s, {GlobalCounters.global_mem/x:.2f} GB/s, param {param_bytes/x:.2f} GB/s"):
+          with Timing("enqueue in ", on_exit=(lambda et: (f", {(GlobalCounters.time_sum_s-st)*1e3:.2f} ms on GPU" if DEBUG>=2 else "")+
+                      f", {GlobalCounters.global_ops*1e-9:.2f} GOPS, {GlobalCounters.global_mem*1e-9:.2f} GB"+
+                      (f", {GlobalCounters.global_mem*1e-9/(GlobalCounters.time_sum_s-st):.2f} GB/s, param {param_bytes*1e-9/(GlobalCounters.time_sum_s-st):.2f} GB/s" if DEBUG>=2 else "")) if DEBUG else None):
+            tok = model(Tensor([[last_tok]], device=device), start_pos, TEMPERATURE, TOP_K, TOP_P, ALPHA_F, ALPHA_P)
+          tok = tok.item()
+      start_pos += 1
+      last_tok = tok
+      generated += tokenizer.decode([tok])
+      print(generated)
+    if "LLaMA-3/8B-SF-DPO" in args.model.as_posix():
+      EXPECTED_TEXT = {
+        1: "Hello! How can I help you today? If you have any questions or need assistance with anything,",
+        2: "Hello! How can I help you today? If you have any questions, need assistance or just want",
+        3: "Hello! How can I help you today? If you have any questions or need assistance, feel free",
+        4: "Hello! How can I assist you today? If you have any questions, need information, or require",
+        5: "Hello! How can I assist you today? If you have any questions or need help with something",
+        6: "Hello! How can I assist you today? If you have any questions, need information, or require",
+      }
+      assert generated == EXPECTED_TEXT[args.shard], f"{generated=} {EXPECTED_TEXT[args.shard]}"
+      print("\n" + colored("output validated", "green"))  # NOTE: "\n" inside colored does not render the color in github action
+  else:
+    prompt = [tokenizer.bos_id] + encode_message("system", "You are a helpful assistant.")
+
+    start_pos = prefill(model, prompt)
+    while True:
+      toks = encode_message("user", input("Q: ")) + encode_role("assistant")
+
+      start_pos = prefill(model, toks[:-1], start_pos=start_pos)
+      last_tok = toks[-1]
+      while True:
+        GlobalCounters.reset()
+        if args.timing or args.profile: print("")
+        st = GlobalCounters.time_sum_s
+        with Profiling(enabled=args.profile):
+          with Timing("total ", enabled=args.timing, on_exit=lambda x: f", {1e9/x:.2f} tok/s, {GlobalCounters.global_mem/x:.2f} GB/s, param {param_bytes/x:.2f} GB/s"):
+            with Timing("enqueue in ", on_exit=(lambda et: (f", {(GlobalCounters.time_sum_s-st)*1e3:.2f} ms on GPU" if DEBUG>=2 else "")+
+                        f", {GlobalCounters.global_ops*1e-9:.2f} GOPS, {GlobalCounters.global_mem*1e-9:.2f} GB"+
+                        (f", {GlobalCounters.global_mem*1e-9/(GlobalCounters.time_sum_s-st):.2f} GB/s, param {param_bytes*1e-9/(GlobalCounters.time_sum_s-st):.2f} GB/s" if DEBUG>=2 else "")) if DEBUG else None, enabled=args.timing):
+
+              tok = model(Tensor([[last_tok]], device=device), start_pos, TEMPERATURE, TOP_K, TOP_P, ALPHA_F, ALPHA_P)
+            tok = tok.item()
+        start_pos += 1
+        last_tok = tok
+        if tok in tokenizer.stop_tokens: break
+        print(tokenizer.decode([tok]), end="", flush=True)
+      print(flush=True)
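For reference, a minimal streaming client sketch for the /v1/chat/completions endpoint implemented above. The host and port are placeholders (use whatever --host/--port the server was started with), and the third-party `requests` package is assumed:

import json, requests

url = "http://localhost:8000/v1/chat/completions"  # placeholder; match the server's --host/--port
body = {"stream": True, "messages": [{"role": "user", "content": "Hello."}]}
with requests.post(url, json=body, stream=True) as r:
  for line in r.iter_lines():
    if not line.startswith(b"data: "): continue
    chunk = json.loads(line[len(b"data: "):])
    if chunk["choices"][0]["finish_reason"] == "stop": break
    print(chunk["choices"][0]["delta"].get("content", ""), end="", flush=True)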

+ 3 - 0
tinychat/examples/llm.c/.gitignore

@@ -0,0 +1,3 @@
+data
+out.c
+a.out

+ 106 - 0
tinychat/examples/llm.c/export.py

@@ -0,0 +1,106 @@
+#!/usr/bin/env python3
+import os
+if "NOOPT" not in os.environ: os.environ["NOOPT"] = "1"
+from tinygrad import Device, nn, Tensor, dtypes, Variable
+Device.DEFAULT = "CLANG"
+from train_gpt2 import GPT, GPTConfig
+from tinygrad.helpers import dedup, to_function_name, flatten, getenv, GRAPH, GlobalCounters, ansilen
+from tinygrad.engine.schedule import create_schedule, memory_planner
+from tinygrad.engine.realize import get_kernel, run_schedule
+from tinygrad.ops import BufferOps, MetaOps
+
+TIMING = getenv("TIMING")
+
+if __name__ == "__main__":
+  model = GPT(GPTConfig(n_layer=getenv("NLAYER", 12), n_head=12, n_embd=768))
+  #model.load_pretrained()
+  for p in nn.state.get_parameters(model): p.replace(Tensor.empty(p.shape, dtype=p.dtype)) # fake load pretrained
+
+  seen = set()
+  #early_sched = create_schedule([x.lazydata for x in nn.state.get_parameters(model)], seen)
+  #print(f"built model {len(early_sched)}")
+
+  #B, T = Variable("B", 1, 128).bind(4), 64 #Variable("T", 1, 1024).bind(64)
+  B, T = 4, 64
+
+  Tensor.training = True
+  optimizer = nn.optim.Adam(nn.state.get_parameters(model), lr=1e-4)
+  warmup_count = getenv("WARMUP", 3)
+  for i in range(warmup_count):  # TODO: why does it take three and not two to stabilize
+    if i == warmup_count-1: GRAPH.value = getenv("LATEGRAPH")
+    GlobalCounters.reset()
+    X = Tensor.empty(4, 64, dtype=dtypes.int).reshape(B, T)
+    Y = Tensor.empty(4, 64, dtype=dtypes.int).reshape(B, T)
+    _, loss = model(X, Y)
+    optimizer.zero_grad()
+    if getenv("BACKWARD", 1):
+      loss.backward()
+      tensors = optimizer.schedule_step()
+    else:
+      tensors = []
+    sched = create_schedule([loss.lazydata] + [x.lazydata for x in tensors], seen)
+    print(f"calls {i}:", len(sched))
+    #run_schedule(sched[:])
+  del seen  # free the LazyBuffers
+  sched = memory_planner(sched)
+  ast_dedup = dedup([si.ast for si in sched if si.ast[0].op is BufferOps.STORE])
+  srcs = {}
+  for ast in ast_dedup:
+    k = get_kernel(Device["CLANG"].renderer, ast)
+    k.linearize()
+    src = Device["CLANG"].renderer.render(to_function_name(k.name), k.uops)
+    srcs[ast] = (k.name, src)
+  print("functions:", len(srcs))
+  used_buffers = dedup(flatten([si.bufs for si in sched]))
+  numbered_bufs = {x:i for i,x in enumerate(used_buffers)}
+  print("buffers:", len(numbered_bufs))
+
+  state_dict = nn.state.get_state_dict(model)
+  state_dict.update({'X': X, 'Y': Y, 'loss': loss})
+  grad_state_dict = {}
+  for k,v in state_dict.items():
+    if v.lazydata.base.buffer not in used_buffers: print(f"UNUSED: {k}")
+    if v.grad is not None: grad_state_dict['grad_'+k] = v.grad
+  state_dict.update(grad_state_dict)
+  state_dict.update({'adam_b1_t': optimizer.b1_t, 'adam_b2_t': optimizer.b2_t, 'adam_lr': optimizer.lr})
+  inverse_state_dict = {v:k for k,v in state_dict.items()}
+  for p,m,v in zip(optimizer.params, optimizer.m, optimizer.v):
+    nm = inverse_state_dict[p]
+    state_dict["adam_m_"+nm] = m
+    state_dict["adam_v_"+nm] = v
+  named_buffers = {v.lazydata.base.buffer:k.replace(".", "_") for k,v in state_dict.items()}
+
+  c_code = ["#include <stdlib.h>", "#include <tgmath.h>", "#include <stdbool.h>"]
+  if TIMING: c_code += ["#include <stdio.h>", "#include <time.h>"]
+  c_code += [x[1].replace(" restrict ", " ")+"\n" for x in srcs.values()]
+
+  premain = ["int main() {"]
+  if TIMING:
+    premain += ["  struct timespec tm0; clock_gettime(CLOCK_MONOTONIC, &tm0);"]
+  lst = 0
+  main = []
+
+  all_bufs = []
+  for i,si in enumerate(sched):
+    bufs = [(named_buffers.get(b, f"b{numbered_bufs[b]}"), b) for b in si.bufs]
+    all_bufs += bufs
+    if si.ast[0].op is not BufferOps.STORE:
+      print(f"// {si.ast[0].op}", bufs)
+    else:
+      print(f"{srcs[si.ast][0]}({', '.join([x[0] for x in bufs])})")
+      main.append(f"  {to_function_name(srcs[si.ast][0])}({', '.join([x[0] for x in bufs])});")
+      if TIMING:
+        main.append(f"  struct timespec tm{i+1}; clock_gettime(CLOCK_MONOTONIC, &tm{i+1});")
+        main.append(f"  printf(\"%10.2f ms + %7.2f ms @ {to_function_name(srcs[si.ast][0])}\\n\"," +\
+                    f"((tm{i+1}.tv_sec-tm{0}.tv_sec) + (tm{i+1}.tv_nsec-tm{0}.tv_nsec) / 1e9) * 1e3," +\
+                    f"((tm{i+1}.tv_sec-tm{lst}.tv_sec) + (tm{i+1}.tv_nsec-tm{lst}.tv_nsec) / 1e9) * 1e3);")
+      lst = i+1
+      #call = f"{srcs[si.ast][0]}({', '.join(bufs)})"
+      #call += " "*(80-ansilen(call))
+      #print(f"{call} // {i+1}")
+      #print(srcs[si.ast][1])
+  main.append("}")
+
+  mallocs = [f"  {b.dtype.name}* {n} = ({b.dtype.name}*)malloc({b.nbytes});" for n,b in dedup(all_bufs)]
+
+  with open("out.c", "w") as f: f.write('\n'.join(c_code+premain+mallocs+main))

+ 192 - 0
tinychat/examples/llm.c/train_gpt2.py

@@ -0,0 +1,192 @@
+#!/usr/bin/env python3
+import os, math, time
+import numpy as np
+from tinygrad import Tensor, nn, fetch, Device, TinyJit, GlobalCounters
+from dataclasses import dataclass
+
+@dataclass
+class GPTConfig:
+  block_size: int = 1024
+  vocab_size: int = 50257
+  n_layer: int = 12
+  n_head: int = 12
+  n_embd: int = 768
+
+class CausalSelfAttention:
+  def __init__(self, config:GPTConfig):
+    assert config.n_embd % config.n_head == 0
+    # key, query, value projections for all heads, but in a batch
+    self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
+    # output projection
+    self.c_proj = nn.Linear(config.n_embd, config.n_embd)
+    # regularization
+    self.n_head = config.n_head
+    self.n_embd = config.n_embd
+    # not really a 'bias', more of a mask, but following the OpenAI/HF naming
+    self.bias = Tensor.ones(1, 1, config.block_size, config.block_size).tril()
+    self.bias.requires_grad = False
+
+  def __call__(self, x:Tensor):
+    B, T, C = x.shape
+    qkv = self.c_attn(x)
+    q, k, v = qkv.split(self.n_embd, dim=2)
+    k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
+    q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
+    v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
+
+    # manual implementation of attention
+    att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
+    att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
+    att = att.softmax()
+    y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
+    y = y.transpose(1, 2).view(B, T, C) # re-assemble all head outputs side by side
+    # output projection
+    y = self.c_proj(y)
+    return y
+
+class MLP:
+  def __init__(self, config:GPTConfig):
+    self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd)
+    self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd)
+
+  def __call__(self, x:Tensor) -> Tensor:
+    return self.c_proj(self.c_fc(x).gelu())
+
+class Block:
+  def __init__(self, config:GPTConfig):
+    self.ln_1 = nn.LayerNorm(config.n_embd)
+    self.attn = CausalSelfAttention(config)
+    self.ln_2 = nn.LayerNorm(config.n_embd)
+    self.mlp = MLP(config)
+
+  def __call__(self, x:Tensor):
+    x = x + self.attn(self.ln_1(x))
+    x = x + self.mlp(self.ln_2(x))
+    return x
+
+class GPT:
+  def __init__(self, config:GPTConfig):
+    self.config = config
+
+    self.wte = nn.Embedding(config.vocab_size, config.n_embd)
+    self.wpe = nn.Embedding(config.block_size, config.n_embd)
+    self.h = [Block(config) for _ in range(config.n_layer)]
+    self.ln_f = nn.LayerNorm(config.n_embd)
+    self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+    self.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying
+
+  def load_pretrained(self):
+    weights = nn.state.torch_load(fetch(f'https://huggingface.co/gpt2/resolve/main/pytorch_model.bin'))
+    transposed = ('attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight')
+    for k in weights:
+      if k.endswith(transposed):
+        weights[k] = weights[k].to(Device.DEFAULT).T.contiguous()
+    # lm head and wte are tied
+    weights['lm_head.weight'] = weights['wte.weight']
+    nn.state.load_state_dict(self, weights)
+
+  def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
+    for _ in range(max_new_tokens):
+      idx_cond = idx if idx.shape[1] <= self.config.block_size else idx[:, -self.config.block_size:]
+      logits, _ = self(idx_cond)
+      logits = logits[:, -1, :] / temperature
+      idx_next = logits.softmax().multinomial()
+      idx = Tensor.cat(idx, idx_next, dim=1)
+    return idx
+
+  def __call__(self, idx:Tensor, targets=None):
+    b, t = idx.shape
+    pos = Tensor.arange(0, t)
+
+    tok_emb = self.wte(idx) # token embeddings of shape (b, t, n_embd)
+    pos_emb = self.wpe(pos) # position embeddings of shape (t, n_embd)
+    x = tok_emb + pos_emb
+
+    x = self.ln_f(x.sequential(self.h))
+
+    if targets is not None:
+      logits = self.lm_head(x)
+      loss = logits.sparse_categorical_crossentropy(targets)
+    else:
+      logits = self.lm_head(x[:, [-1], :])
+      loss = None
+
+    return logits, loss
+
+if __name__ == "__main__":
+  import tiktoken, argparse
+
+  parser = argparse.ArgumentParser()
+  parser.add_argument("--num_iterations", type=int, default=10, help="number of iterations to run")
+  parser.add_argument("--batch_size", type=int, default=4, help="batch size")
+  parser.add_argument("--sequence_length", type=int, default=64, help="sequence length")
+  args = parser.parse_args()
+  B, T = args.batch_size, args.sequence_length
+  assert 1 <= T <= 1024
+
+  model = GPT(GPTConfig(n_layer=12, n_head=12, n_embd=768))
+  model.load_pretrained()
+
+  # init the tokenizer
+  enc = tiktoken.get_encoding("gpt2")
+  encode = lambda s: enc.encode(s, allowed_special={"<|endoftext|>"})
+  decode = lambda l: enc.decode(l)
+
+  # load the tokens
+  # prefer to use tiny_shakespeare if it's available, otherwise use tiny_stories
+  # we're using val instead of train split just because it is smaller/faster
+  shake_tokens_bin = "data/tiny_shakespeare_val.bin"
+  story_tokens_bin = "data/TinyStories_val.bin"
+  assert os.path.isfile(shake_tokens_bin) or os.path.isfile(story_tokens_bin), "you must run prepro on some dataset"
+  tokens_bin = shake_tokens_bin if os.path.isfile(shake_tokens_bin) else story_tokens_bin
+  assert os.path.isfile(tokens_bin)
+  print(f"loading cached tokens in {tokens_bin}")
+  with open(tokens_bin, "rb") as f:
+    f.seek(0x400)
+    tokens = np.frombuffer(f.read(), dtype=np.uint16).astype(np.int32)
+  tokens = Tensor(tokens)
+
+  # lightweight dataloader
+  def get_batch():
+    assert B*T+1 <= len(tokens), "not enough tokens"
+    # for 338,025 tokens. E.g. with B=8 T=1024, this will yield 41 batches before looping
+    i = 0
+    while True:
+      x = tokens[i:i+B*T].view(B, T)
+      y = tokens[i+1:i+B*T+1].view(B, T)
+      yield x, y
+      i += B*T
+      if i + B*T + 1 >= len(tokens):
+        i = 0 # in prod we'd want to randomize the start point a bit
+
+  # forward backward for a few iterations
+  data_iter = iter(get_batch())
+  x, y = next(data_iter) # we'll overfit this batch below
+  optimizer = nn.optim.AdamW(nn.state.get_parameters(model), lr=1e-4, weight_decay=0)
+
+  @TinyJit
+  def step(x, y):
+    _, loss = model(x, y)
+    optimizer.zero_grad()
+    loss.backward()
+    optimizer.step()
+    return loss
+
+  with Tensor.train():
+    for i in range(args.num_iterations):
+      GlobalCounters.reset()
+      t0 = time.time()
+      loss = step(x.contiguous(), y.contiguous())
+      Device[Device.DEFAULT].synchronize()
+      t1 = time.time()
+      print(f"iteration {i}, loss: {loss.item()}, time: {(t1-t0)*1000:.3f}ms")
+
+  start = "<|endoftext|>"
+  start_ids = encode(start)
+  x = (Tensor(start_ids)[None, ...])
+  max_new_tokens = 16
+  temperature = 1.0
+  top_k = 40
+  y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
+  print(decode(y[0].tolist()))
+
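The loader above only skips a 1024-byte header (f.seek(0x400)) and then reads raw uint16 token ids, so for a quick smoke test a compatible file can be written by hand. A hedged sketch ("input.txt" is a placeholder; the real llm.c preprocessing writes proper header fields, which this reader happens to ignore):

import os
import numpy as np, tiktoken
os.makedirs("data", exist_ok=True)
enc = tiktoken.get_encoding("gpt2")
toks = np.array(enc.encode_ordinary(open("input.txt").read()), dtype=np.uint16)
with open("data/tiny_shakespeare_val.bin", "wb") as f:
  f.write(b"\x00" * 0x400)  # 1024-byte header; the reader above seeks past it without inspecting it
  f.write(toks.tobytes())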

+ 65 - 0
tinychat/examples/llm.c/ubench/matmul.c

@@ -0,0 +1,65 @@
+// clang -Ofast -Wno-unused-result -march=native matmul.c
+#include <stdio.h>
+#include <stdlib.h>
+#include <time.h>
+
+float b52[786432];
+float b49[196608];
+float h_0_mlp_c_fc_weight[2359296];
+float h_0_mlp_c_fc_bias[3072];
+
+void matmul_forward(float* out,
+                    float* inp, float* weight, float* bias,
+                    int B, int T, int C, int OC) {
+    // most of the running time is spent here and in matmul_backward
+    // OC is short for "output channels"
+    // inp is (B,T,C), weight is (OC, C), bias is (OC)
+    // out will be (B,T,OC)
+    #pragma omp parallel for collapse(2)
+    for (int b = 0; b < B; b++) {
+        for (int t = 0; t < T; t++) {
+            float* out_bt = out + b * T * OC + t * OC;
+            float* inp_bt = inp + b * T * C + t * C;
+            for (int o = 0; o < OC; o++) {
+                float val = (bias != NULL) ? bias[o] : 0.0f;
+                float* wrow = weight + o*C;
+                for (int i = 0; i < C; i++) {
+                    val += inp_bt[i] * wrow[i];
+                }
+                out_bt[o] = val;
+            }
+        }
+    }
+}
+
+
+void r_256_3072_768(float* restrict data0, const float* restrict data1, const float* restrict data2, const float* restrict data3) {
+  for (int ridx0 = 0; ridx0 < 256; ridx0++) {
+    for (int ridx1 = 0; ridx1 < 3072; ridx1++) {
+      float acc0 = 0.0f;
+      float val0 = data3[ridx1];
+      for (int ridx2 = 0; ridx2 < 768; ridx2++) {
+        float val1 = data1[(ridx0*768)+ridx2];
+        float val2 = data2[(ridx1*768)+ridx2];
+        acc0 = ((val1*val2)+acc0);
+      }
+      data0[(ridx0*3072)+ridx1] = (acc0+val0);
+    }
+  }
+}
+
+
+int main() {
+  for (int i = 0; i < 5; i++) {
+    struct timespec t1, t2, t3;
+    clock_gettime(CLOCK_MONOTONIC, &t1);
+    r_256_3072_768(b52, b49, h_0_mlp_c_fc_weight, h_0_mlp_c_fc_bias);
+    clock_gettime(CLOCK_MONOTONIC, &t2);
+    matmul_forward(b52, b49, h_0_mlp_c_fc_weight, h_0_mlp_c_fc_bias, 4, 64, 768, 3072);
+    clock_gettime(CLOCK_MONOTONIC, &t3);
+    double time_gen = (t2.tv_sec - t1.tv_sec) + (t2.tv_nsec - t1.tv_nsec) / 1e9;
+    double time_real = (t3.tv_sec - t2.tv_sec) + (t3.tv_nsec - t2.tv_nsec) / 1e9;
+    printf("%.2f ms gen vs %.2f ms reference\n", time_gen*1e3, time_real*1e3);
+  }
+}
+

+ 316 - 0
tinychat/examples/mamba.py

@@ -0,0 +1,316 @@
+import os, sys, math, argparse, time
+sys.path.append(os.getcwd())
+from typing import Any, Optional, Dict
+
+from tinygrad import Tensor, TinyJit, nn
+from tinygrad.helpers import fetch
+from tinygrad.nn.state import load_state_dict, torch_load
+
+from tqdm import tqdm
+from transformers import AutoTokenizer
+
+MODELS = {
+  "130m": {"dim":  768, "n_layers": 24, "vocab_size": 50277, "pad_vocab_size_multiple": 8},
+  "370m": {"dim": 1024, "n_layers": 48, "vocab_size": 50277, "pad_vocab_size_multiple": 8},
+  "790m": {"dim": 1536, "n_layers": 48, "vocab_size": 50277, "pad_vocab_size_multiple": 8},
+  "1.4b": {"dim": 2048, "n_layers": 48, "vocab_size": 50277, "pad_vocab_size_multiple": 8},
+  "2.8b": {"dim": 2560, "n_layers": 64, "vocab_size": 50277, "pad_vocab_size_multiple": 8},
+}
+
+def fetch_weights(model_name: str) -> Dict[str, Tensor]:
+  if model_name not in MODELS:
+    raise ValueError(f"Requested unknown mamba model: {model_name}")
+  downloaded = fetch(f"https://huggingface.co/state-spaces/mamba-{model_name}/resolve/main/pytorch_model.bin?download=true")
+  return torch_load(downloaded)
+
+def selective_scan_ref(
+  u,
+  delta,
+  A,
+  B,
+  C,
+  D=None,
+  z=None,
+  delta_bias=None,
+  delta_softplus=False,
+  return_last_state=False,
+):
+  """
+  u: r(B D L)
+  delta: r(B D L)
+  A: c(D N) or r(D N)
+  B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
+  C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
+  D: r(D)
+  z: r(B D L)
+  delta_bias: r(D), fp32
+
+  out: r(B D L)
+  last_state (optional): r(B D dstate) or c(B D dstate)
+  """
+  u = u.float()
+  delta = delta.float()
+  if delta_bias is not None:
+    delta = delta + delta_bias[..., None].float()
+  if delta_softplus:
+    delta = delta.softplus()
+  batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1]
+  is_variable_B = len(B.shape) >= 3
+  is_variable_C = len(C.shape) >= 3
+  x = Tensor.zeros(batch, dim, dstate)
+  ys = []
+  deltaA = Tensor.einsum("bdl,dn->bdln", delta, A).exp()
+  if not is_variable_B:
+    deltaB_u = Tensor.einsum("bdl,dn,bdl->bdln", delta, B, u)
+  else:
+    if len(B.shape) == 3:
+      deltaB_u = Tensor.einsum("bdl,bnl,bdl->bdln", delta, B, u)
+    else:
+      B = B.repeat((1, dim // B.shape[1], 1, 1))
+      deltaB_u = Tensor.einsum("bdl,bdnl,bdl->bdln", delta, B, u)
+  if is_variable_C and len(C.shape) == 4:
+    C = C.repeat((1, dim // C.shape[1], 1, 1))
+  last_state = None
+  for i in range(u.shape[2]):
+    x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
+    if not is_variable_C:
+      y = Tensor.einsum("bdn,dn->bd", x, C)
+    else:
+      if len(C.shape) == 3:
+        y = Tensor.einsum("bdn,bn->bd", x, C[:, :, i])
+      else:
+        y = Tensor.einsum("bdn,bdn->bd", x, C[:, :, :, i])
+    if i == u.shape[2] - 1:
+      last_state = x
+    ys.append(y)
+  y = Tensor.stack(*ys, dim=2)  # (batch dim L)
+  out = y if D is None else y + u * D.reshape((-1, 1))
+  if z is not None:
+    out = out * z.silu()
+  return out if not return_last_state else (out, last_state)
+
+class MambaMixer:
+  def __init__(
+    self,
+    dim,
+    d_state=16,
+    d_conv=4,
+    expand=2,
+    dt_rank="auto",
+    dt_min=0.001,
+    dt_max=0.1,
+    dt_init="random",
+    dt_scale=1.0,
+    dt_init_floor=1e-4,
+    conv_bias=True,
+    bias=False,
+    layer_idx=None,
+  ):
+    self.dim = dim
+    self.d_state = d_state
+    self.d_conv = d_conv
+    self.expand = expand
+    self.d_inner = self.expand * self.dim
+    self.dt_rank = math.ceil(self.dim / 16) if dt_rank == "auto" else dt_rank
+    self.layer_idx = layer_idx
+
+    self.in_proj = nn.Linear(self.dim, self.d_inner * 2, bias=bias)
+
+    self.conv1d = nn.Conv1d(in_channels=self.d_inner, out_channels=self.d_inner, bias=conv_bias,
+                            kernel_size=d_conv, groups=self.d_inner, padding=d_conv-1)
+
+    self.x_proj = nn.Linear(self.d_inner, self.dt_rank + self.d_state * 2, bias=False)
+    self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True)
+
+    # Initialize special dt projection to preserve variance at initialization
+    dt_init_std = self.dt_rank**-0.5 * dt_scale
+    if dt_init == "constant":
+      self.dt_proj.weight = Tensor.full(self.dt_proj.weight.shape, dt_init_std)
+    elif dt_init == "random":
+      self.dt_proj.weight = Tensor.uniform(self.dt_proj.weight.shape, low=-dt_init_std, high=dt_init_std)
+    else:
+      raise NotImplementedError
+
+    dt = Tensor.uniform(self.d_inner, low=math.log(dt_min), high=math.log(dt_max)).exp().maximum(dt_init_floor)
+    inv_dt = dt + (1 - (-dt).exp()).log()
+
+    self.dt_proj.bias.assign(inv_dt)
+
+    # S4D real initialization
+    self.A_log = Tensor.arange(1, self.d_state+1).repeat([self.d_inner, 1]).log()
+
+    # D "skip" parameter
+    self.D = Tensor.ones(self.d_inner)  # Keep in fp32
+
+    self.out_proj = nn.Linear(self.d_inner, self.dim, bias=bias)
+
+  def __call__(self, hidden_states: Tensor):
+    batch, seqlen, _ = hidden_states.shape
+
+    if not hasattr(self, 'conv_state'):
+      self.conv_state = Tensor.zeros(batch, self.dim * self.expand, self.d_conv).contiguous().realize()
+      self.ssm_state = Tensor.zeros(batch, self.dim * self.expand, self.d_state).realize()
+
+      xz = self.in_proj.weight @ hidden_states.permute(2,0,1).reshape(hidden_states.shape[2],hidden_states.shape[1]*hidden_states.shape[0])
+      xz = xz.reshape(xz.shape[0],xz.shape[1]//seqlen, seqlen).permute(1,0,2)
+
+      if self.in_proj.bias is not None:
+        xz = xz + self.in_proj.bias.reshape((-1, 1))
+
+      A = -self.A_log.exp()
+      x, z = xz.chunk(2, dim=1)
+      # Compute short convolution
+      self.conv_state.assign(x[:, :, -self.d_conv :])  # Update state (B D W)
+      x = self.conv1d(x)[..., :seqlen].swish()
+
+      x_dbl = self.x_proj(x.permute(0,2,1).reshape(x.shape[0]*x.shape[2], x.shape[1]))
+      dt, B, C = Tensor.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
+      dt = self.dt_proj.weight @ dt.T
+      dt = dt.reshape(dt.shape[0], dt.shape[1]//seqlen, seqlen).permute(1,0,2)
+      B = B.reshape(B.shape[0]//seqlen, seqlen, B.shape[1]).permute(0,2,1)
+      C = C.reshape(C.shape[0]//seqlen, seqlen, C.shape[1]).permute(0,2,1)
+
+      # TODO: actually implement selective_scan_fn
+      y = selective_scan_ref(x, dt, A, B, C, self.D, z=z, delta_bias=self.dt_proj.bias, delta_softplus=True,
+                            return_last_state=True)
+
+      y, last_state = y
+      self.ssm_state.assign(last_state).realize()
+      y = y.permute(0,2,1)
+      out = self.out_proj(y)
+      return out
+    else:
+      return self.step(hidden_states)
+
+  def step(self, hidden_states: Tensor):
+    assert hidden_states.shape[1] == 1, f"Only support decoding with 1 token at a time for now, attempted {hidden_states.shape[1]}"
+    xz = self.in_proj(hidden_states.squeeze(1))  # (B 2D)
+    x, z = xz.chunk(2, dim=-1)  # (B D)
+
+    # Conv step
+    self.conv_state.assign(self.conv_state[:, :, 1:].cat(x.unsqueeze(-1), dim=-1).realize())
+    x = (self.conv_state * self.conv1d.weight.squeeze(1)).sum(-1)
+    if self.conv1d.bias is not None:
+      x = x + self.conv1d.bias
+    x = x.swish()
+
+    x_db = self.x_proj(x)  # (B dt_rank+2*d_state)
+    dt = x_db[:, : self.dt_rank]
+    B = x_db[:, self.dt_rank : (self.dt_rank + self.d_state)]
+    C = x_db[:, (self.dt_rank + self.d_state) :]
+    # Don't add dt_bias here
+    dt = self.dt_proj.weight @ dt.T
+    A = -self.A_log.exp()
+
+    # SSM step
+    dt = (dt + self.dt_proj.bias.unsqueeze(-1)).softplus()
+    dA = Tensor.einsum("db,dn->bdn", dt, A).exp()
+    dB = Tensor.einsum("db,bn->bdn", dt, B)
+    self.ssm_state.assign(self.ssm_state * dA + x.unsqueeze(-1) * dB)
+    y = Tensor.einsum("bdn,bn->bd", self.ssm_state, C)
+    y = y + self.D * x
+    y = y * z.swish()  # (B D)
+
+    out = self.out_proj(y)
+    return out.unsqueeze(1)
+
+class MambaBlock:
+  def __init__(self, dim: int, norm_eps: float = 1e-5, rms_norm: bool = True, layer_idx: Optional[int] = None):
+    self.mixer = MambaMixer(dim, layer_idx=layer_idx)
+    if rms_norm:
+      self.norm = nn.RMSNorm(dim, norm_eps)
+    else:
+      raise NotImplementedError
+
+  def __call__(self, hidden_states: Tensor, residual: Optional[Tensor] = None):
+    residual = (hidden_states + residual) if residual is not None else hidden_states
+    hidden_states = self.norm(residual)
+    hidden_states = self.mixer(hidden_states)
+    return hidden_states, residual
+
+class MambaBackbone:
+  def __init__(self, dim: int, n_layers: int, vocab_size: int, rms_norm: bool = True, norm_eps: float = 1e-5):
+    self.embedding = nn.Embedding(vocab_size, dim)
+    self.layers = [MambaBlock(dim, rms_norm=rms_norm, layer_idx=i) for i in range(n_layers)]
+    if rms_norm:
+      self.norm_f = nn.RMSNorm(dim, norm_eps)
+
+  def __call__(self, input_ids: Tensor) -> Any:
+    hidden_states = self.embedding(input_ids)
+    residual = None
+    for layer in self.layers:
+      hidden_states, residual = layer(hidden_states, residual)
+
+    residual = (hidden_states + residual) if residual is not None else hidden_states
+    hidden_states = self.norm_f(residual)
+    return hidden_states
+
+class Mamba:
+  def __init__(self, dim: int, n_layers: int, vocab_size: int, pad_vocab_size_multiple: int = 1):
+    if vocab_size % pad_vocab_size_multiple != 0:
+      vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple)
+
+    self.backbone = MambaBackbone(dim, n_layers, vocab_size)
+    self.lm_head = nn.Linear(dim, vocab_size, bias=False)
+
+    self.forward_jit = TinyJit(self.forward)
+
+  def forward(self, input_ids:Tensor):
+    hidden_states = self.backbone(input_ids)
+    return self.lm_head(hidden_states).realize()
+
+  def __call__(self, input_ids):
+    return self.forward(input_ids)
+
+  @staticmethod
+  def from_pretrained(model_name: str):
+    weights = fetch_weights(model_name)
+    model = Mamba(**MODELS[model_name])
+    load_state_dict(model, weights)
+
+    return model
+
+
+def generate(model, tokenizer, prompt: str, n_tokens_to_gen: int = 10, temp: float = 1.0, sample: bool = False, top_k: Optional[int] = None):
+  tks = tokenizer(prompt)["input_ids"]
+  while len(tks) < 4:
+    tks = [50279] + tks
+
+  # Loading in the prompt tokens
+  logits = model.forward(Tensor([tks]))[:, -1, :]
+  for _ in tqdm(range(n_tokens_to_gen), desc="Speed Gen"):
+    # TODO: topk
+    if sample:
+      tok_Tens = (logits/temp).softmax().multinomial()
+    else:
+      tok_Tens = logits.argmax(axis=-1).unsqueeze(0)
+    tok = tok_Tens.item()
+    tks.append(tok)
+    logits = model.forward_jit(tok_Tens)[:, -1, :]
+
+  output_completions = ''.join([tokenizer.decode(output) for output in tks])
+  return output_completions
+
+if __name__ == "__main__":
+  ORIG_PROMPT = "Why is gravity "
+  parser = argparse.ArgumentParser(description="Run Mamba in tinygrad", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+  parser.add_argument("--prompt", type=str, default="Why is gravity ", help="Prompt for LLM completion")
+  parser.add_argument("--size", type=str, default="370m",
+                      help=f"Size of model to use [{', '.join([k for k in MODELS.keys()])}]")
+  parser.add_argument("--n_tokens", type=int, default=10, help="Number of tokens to generate")
+  parser.add_argument("--sample", dest="sample", action="store_true", help="Sample flag")
+  parser.add_argument("--temp", type=float, default=1.0, help="Sampling temp has to be <=1.0")
+  args = parser.parse_args()
+
+  tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
+  model = Mamba.from_pretrained(args.size)
+  prompt = args.prompt
+  num_toks = args.n_tokens
+  sample = args.sample
+  temp = args.temp
+  s = time.time()
+  tinyoutput = generate(model, tokenizer, prompt, n_tokens_to_gen=num_toks, sample=sample, temp=temp)
+  print(tinyoutput)
+  print('TIME: ', time.time() - s)
+  TORCHOUTPUT = "Why is gravity \nso important?\nBecause it's the only"
+  if ORIG_PROMPT == prompt and not sample and num_toks==10 and args.size=='370m': print('Outputs Match:', tinyoutput == TORCHOUTPUT)
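For intuition, a tiny NumPy sketch of the recurrence that selective_scan_ref above implements, written out for a single (batch, channel) pair with state size N; all values are made up for illustration:

import numpy as np
L, N = 5, 4
u, delta = np.random.rand(L), np.random.rand(L)     # input sequence and per-step dt
A = -np.random.rand(N)                               # one row of the (D, N) diagonal A
B, C = np.random.rand(L, N), np.random.rand(L, N)    # input-dependent B_t and C_t
D = 0.5                                              # skip-connection weight
x, ys = np.zeros(N), []
for t in range(L):
  dA = np.exp(delta[t] * A)                          # discretized state transition
  dBu = delta[t] * B[t] * u[t]                       # discretized input injection
  x = dA * x + dBu
  ys.append(float(x @ C[t]) + D * u[t])              # y_t = C_t . x_t + D * u_t
print(ys)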

+ 299 - 0
tinychat/examples/mask_rcnn.py

@@ -0,0 +1,299 @@
+from extra.models.mask_rcnn import MaskRCNN
+from extra.models.resnet import ResNet
+from extra.models.mask_rcnn import BoxList
+from torch.nn import functional as F
+from torchvision import transforms as T
+from torchvision.transforms import functional as Ft
+import random
+from tinygrad.tensor import Tensor
+from PIL import Image
+import numpy as np
+import torch
+import argparse
+import cv2
+
+
+class Resize:
+  def __init__(self, min_size, max_size):
+    if not isinstance(min_size, (list, tuple)):
+      min_size = (min_size,)
+    self.min_size = min_size
+    self.max_size = max_size
+
+  # modified from torchvision to add support for max size
+  def get_size(self, image_size):
+    w, h = image_size
+    size = random.choice(self.min_size)
+    max_size = self.max_size
+    if max_size is not None:
+      min_original_size = float(min((w, h)))
+      max_original_size = float(max((w, h)))
+      if max_original_size / min_original_size * size > max_size:
+        size = int(round(max_size * min_original_size / max_original_size))
+
+    if (w <= h and w == size) or (h <= w and h == size):
+      return (h, w)
+
+    if w < h:
+      ow = size
+      oh = int(size * h / w)
+    else:
+      oh = size
+      ow = int(size * w / h)
+
+    return (oh, ow)
+
+  def __call__(self, image):
+    size = self.get_size(image.size)
+    image = Ft.resize(image, size)
+    return image
+
+
+class Normalize:
+  def __init__(self, mean, std, to_bgr255=True):
+    self.mean = mean
+    self.std = std
+    self.to_bgr255 = to_bgr255
+
+  def __call__(self, image):
+    if self.to_bgr255:
+      image = image[[2, 1, 0]] * 255
+    else:
+      image = image[[0, 1, 2]] * 255
+    image = Ft.normalize(image, mean=self.mean, std=self.std)
+    return image
+
+transforms = lambda size_scale: T.Compose(
+  [
+    Resize(int(800*size_scale), int(1333*size_scale)),
+    T.ToTensor(),
+    Normalize(
+      mean=[102.9801, 115.9465, 122.7717], std=[1., 1., 1.], to_bgr255=True
+    ),
+  ]
+)
+
+def expand_boxes(boxes, scale):
+  w_half = (boxes[:, 2] - boxes[:, 0]) * .5
+  h_half = (boxes[:, 3] - boxes[:, 1]) * .5
+  x_c = (boxes[:, 2] + boxes[:, 0]) * .5
+  y_c = (boxes[:, 3] + boxes[:, 1]) * .5
+
+  w_half *= scale
+  h_half *= scale
+
+  boxes_exp = torch.zeros_like(boxes)
+  boxes_exp[:, 0] = x_c - w_half
+  boxes_exp[:, 2] = x_c + w_half
+  boxes_exp[:, 1] = y_c - h_half
+  boxes_exp[:, 3] = y_c + h_half
+  return boxes_exp
+
+
+def expand_masks(mask, padding):
+  N = mask.shape[0]
+  M = mask.shape[-1]
+  pad2 = 2 * padding
+  scale = float(M + pad2) / M
+  padded_mask = mask.new_zeros((N, 1, M + pad2, M + pad2))
+  padded_mask[:, :, padding:-padding, padding:-padding] = mask
+  return padded_mask, scale
+
+
+def paste_mask_in_image(mask, box, im_h, im_w, thresh=0.5, padding=1):
+  # TODO: remove torch
+  mask = torch.tensor(mask.numpy())
+  box = torch.tensor(box.numpy())
+  padded_mask, scale = expand_masks(mask[None], padding=padding)
+  mask = padded_mask[0, 0]
+  box = expand_boxes(box[None], scale)[0]
+  box = box.to(dtype=torch.int32)
+
+  TO_REMOVE = 1
+  w = int(box[2] - box[0] + TO_REMOVE)
+  h = int(box[3] - box[1] + TO_REMOVE)
+  w = max(w, 1)
+  h = max(h, 1)
+
+  mask = mask.expand((1, 1, -1, -1))
+
+  mask = mask.to(torch.float32)
+  mask = F.interpolate(mask, size=(h, w), mode='bilinear', align_corners=False)
+  mask = mask[0][0]
+
+  if thresh >= 0:
+    mask = mask > thresh
+  else:
+    mask = (mask * 255).to(torch.uint8)
+
+  im_mask = torch.zeros((im_h, im_w), dtype=torch.uint8)
+  x_0 = max(box[0], 0)
+  x_1 = min(box[2] + 1, im_w)
+  y_0 = max(box[1], 0)
+  y_1 = min(box[3] + 1, im_h)
+
+  im_mask[y_0:y_1, x_0:x_1] = mask[
+                              (y_0 - box[1]): (y_1 - box[1]), (x_0 - box[0]): (x_1 - box[0])
+                              ]
+  return im_mask
+
+
+class Masker:
+  def __init__(self, threshold=0.5, padding=1):
+    self.threshold = threshold
+    self.padding = padding
+
+  def forward_single_image(self, masks, boxes):
+    boxes = boxes.convert("xyxy")
+    im_w, im_h = boxes.size
+    res = [
+      paste_mask_in_image(mask[0], box, im_h, im_w, self.threshold, self.padding)
+      for mask, box in zip(masks, boxes.bbox)
+    ]
+    if len(res) > 0:
+      res = torch.stack(res, dim=0)[:, None]
+    else:
+      res = masks.new_empty((0, 1, masks.shape[-2], masks.shape[-1]))
+    return Tensor(res.numpy())
+
+  def __call__(self, masks, boxes):
+    if isinstance(boxes, BoxList):
+      boxes = [boxes]
+
+    results = []
+    for mask, box in zip(masks, boxes):
+      result = self.forward_single_image(mask, box)
+      results.append(result)
+    return results
+
+
+masker = Masker(threshold=0.5, padding=1)
+
+def select_top_predictions(predictions, confidence_threshold=0.9):
+  scores = predictions.get_field("scores").numpy()
+  keep = [idx for idx, score in enumerate(scores) if score > confidence_threshold]
+  return predictions[keep]
+
+def compute_prediction(original_image, model, confidence_threshold, size_scale=1.0):
+  image = transforms(size_scale)(original_image).numpy()
+  image = Tensor(image, requires_grad=False)
+  predictions = model(image)
+  prediction = predictions[0]
+  prediction = select_top_predictions(prediction, confidence_threshold)
+  width, height = original_image.size
+  prediction = prediction.resize((width, height))
+
+  if prediction.has_field("mask"):
+    masks = prediction.get_field("mask")
+    masks = masker([masks], [prediction])[0]
+    prediction.add_field("mask", masks)
+  return prediction
+
+def compute_prediction_batched(batch, model, size_scale=1.0):
+  imgs = []
+  for img in batch:
+    imgs.append(transforms(size_scale)(img).numpy())
+  image = [Tensor(image, requires_grad=False) for image in imgs]
+  predictions = model(image)
+  del image
+  return predictions
+
+palette = np.array([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1])
+
+def findContours(*args, **kwargs):
+  if cv2.__version__.startswith('4'):
+    contours, hierarchy = cv2.findContours(*args, **kwargs)
+  elif cv2.__version__.startswith('3'):
+    _, contours, hierarchy = cv2.findContours(*args, **kwargs)
+  return contours, hierarchy
+
+def compute_colors_for_labels(labels):
+  l = labels[:, None]
+  colors = l * palette
+  colors = (colors % 255).astype("uint8")
+  return colors
+
+def overlay_mask(image, predictions):
+  image = np.asarray(image)
+  masks = predictions.get_field("mask").numpy()
+  labels = predictions.get_field("labels").numpy()
+
+  colors = compute_colors_for_labels(labels).tolist()
+
+  for mask, color in zip(masks, colors):
+    thresh = mask[0, :, :, None]
+    contours, hierarchy = findContours(
+        thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
+    )
+    image = cv2.drawContours(image, contours, -1, color, 3)
+
+  composite = image
+
+  return composite
+
+CATEGORIES = [
+    "__background", "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light",
+    "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow", "elephant",
+    "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard",
+    "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle",
+    "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple", "sandwich", "orange", "broccoli",
+    "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch", "potted plant", "bed", "dining table",
+    "toilet", "tv", "laptop", "mouse", "remote", "keyboard", "cell phone", "microwave", "oven", "toaster",
+    "sink", "refrigerator", "book", "clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush",
+]
+
+def overlay_boxes(image, predictions):
+  labels = predictions.get_field("labels").numpy()
+  boxes = predictions.bbox
+  image = np.asarray(image)
+  colors = compute_colors_for_labels(labels).tolist()
+
+  for box, color in zip(boxes, colors):
+    box = torch.tensor(box.numpy())
+    box = box.to(torch.int64)
+    top_left, bottom_right = box[:2].tolist(), box[2:].tolist()
+    image = cv2.rectangle(
+        image, tuple(top_left), tuple(bottom_right), tuple(color), 1
+    )
+
+  return image
+
+def overlay_class_names(image, predictions):
+  scores = predictions.get_field("scores").numpy().tolist()
+  labels = predictions.get_field("labels").numpy().tolist()
+  labels = [CATEGORIES[int(i)] for i in labels]
+  boxes = predictions.bbox.numpy()
+  image = np.asarray(image)
+  template = "{}: {:.2f}"
+  for box, score, label in zip(boxes, scores, labels):
+    x, y = box[:2]
+    s = template.format(label, score)
+    x, y = int(x), int(y)
+    cv2.putText(
+        image, s, (x, y), cv2.FONT_HERSHEY_SIMPLEX, .5, (255, 255, 255), 1
+    )
+
+  return image
+
+
+if __name__ == '__main__':
+  parser = argparse.ArgumentParser(description='Run MaskRCNN', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+  parser.add_argument('--image', type=str, help="Path of the image to run")
+  parser.add_argument('--threshold', type=float, default=0.7, help="Detector threshold")
+  parser.add_argument('--size_scale', type=float, default=1.0, help="Image resize multiplier")
+  parser.add_argument('--out', type=str, default="/tmp/rendered.png", help="Output filename")
+  args = parser.parse_args()
+
+  resnet = ResNet(50, num_classes=None, stride_in_1x1=True)
+  model_tiny = MaskRCNN(resnet)
+  model_tiny.load_from_pretrained()
+  img = Image.open(args.image)
+  top_result_tiny = compute_prediction(img, model_tiny, confidence_threshold=args.threshold, size_scale=args.size_scale)
+  bbox_image = overlay_boxes(img, top_result_tiny)
+  mask_image = overlay_mask(bbox_image, top_result_tiny)
+  final_image = overlay_class_names(mask_image, top_result_tiny)
+
+  im = Image.fromarray(final_image)
+  print(f"saving {args.out}")
+  im.save(args.out)
+  im.show()

+ 59 - 0
tinychat/examples/mixtral.py

@@ -0,0 +1,59 @@
+import functools, argparse, pathlib
+from tinygrad import Tensor, nn, Device, GlobalCounters, Variable
+from tinygrad.helpers import Timing, Profiling, CI, tqdm
+from tinygrad.nn.state import torch_load, get_state_dict
+from extra.models.llama import FeedForward, Transformer
+
+class MixtureFeedForward:
+  def __init__(self, num_experts:int, dim:int, hidden_dim:int, linear=nn.Linear):
+    self.gate = nn.Linear(dim, num_experts, bias=False)
+    self.experts = [FeedForward(dim, hidden_dim, linear) for _ in range(num_experts)]
+  def __call__(self, x:Tensor) -> Tensor:
+    assert x.shape[0] == 1, "only BS=1"
+    g = self.gate(x).float().exp()
+    choice = g.data().tolist()[0][0]
+    top = sorted(enumerate(choice), key=lambda x: -x[1])
+    norm = top[0][1] + top[1][1]
+    e1, e2 = self.experts[top[0][0]], self.experts[top[1][0]]
+    scale = Tensor([top[0][1]/norm, top[1][1]/norm])
+    ret = e1(x.to(e1.w1.weight.device)).to(x.device) * scale[0] + \
+          e2(x.to(e2.w1.weight.device)).to(x.device) * scale[1]
+    return ret
+
+if __name__ == "__main__":
+  parser = argparse.ArgumentParser(description="Run Mixtral in tinygrad", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+  parser.add_argument("--count", type=int, default=30, help="Max number of tokens to generate")
+  parser.add_argument("--temperature", type=float, default=0.7, help="Temperature in the softmax")
+  parser.add_argument("--timing", action="store_true", help="Print timing per token")
+  parser.add_argument("--profile", action="store_true", help="Profile generation")
+  parser.add_argument("--weights", type=str, default=(pathlib.Path(__file__).parent.parent / "weights/mixtral-8x7b-32kseqlen").as_posix(),
+                      help="Path to the downloaded weights")
+  args = parser.parse_args()
+
+  state = torch_load(args.weights + "/consolidated.00.pth.b")
+  model = Transformer(n_layers=32, dim=4096, hidden_dim=14336, n_heads=32, n_kv_heads=8, norm_eps=1e-5, vocab_size=32000, feed_forward=functools.partial(MixtureFeedForward, 8), jit=False)
+  model_state_dict = get_state_dict(model)
+
+  for k in (t := tqdm(state, disable=CI)):
+    if 'feed_forward.experts.' in k:
+      expert_no = int(k.split('feed_forward.experts.')[1].split('.')[0])
+      device = Device.DEFAULT + ":" + str((expert_no//2)+1)
+    else:
+      device = Device.DEFAULT
+    t.set_description(f"ram used: {GlobalCounters.mem_used/1e9:5.2f} GB, loading {k} to {device}")
+    model_state_dict[k].replace(state[k].to(device).half()).realize()
+  if CI: print(f"ram used: {GlobalCounters.mem_used/1e9:5.2f} GB")
+
+  from sentencepiece import SentencePieceProcessor
+  spp = SentencePieceProcessor(model_file=args.weights + "/tokenizer.model")
+
+  toks = [spp.bos_id()]
+  start_pos = 0
+  for i in range(args.count):
+    GlobalCounters.reset()
+    with Profiling(sort="time", frac=0.1, enabled=args.profile):
+      with Timing("total ", enabled=args.timing, on_exit=lambda x: f", {1e9/x:.2f} tok/sec"):
+        tok = model(Tensor([toks[start_pos:]]), 0 if start_pos == 0 else Variable("start_pos", 1, 1024).bind(start_pos), args.temperature).item()
+    toks.append(tok)
+    start_pos += 1
+    print(spp.decode(toks))
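A small plain-Python sketch of the top-2 routing that MixtureFeedForward performs above; the gate logits are made-up numbers, and the final comment shows how the two selected experts would be mixed:

import math
gate_logits = [0.2, 1.5, -0.3, 0.9, 0.0, -1.2, 0.4, 1.1]   # one logit per expert (8 experts)
g = [math.exp(v) for v in gate_logits]
top = sorted(enumerate(g), key=lambda kv: -kv[1])[:2]       # two highest-scoring experts
norm = top[0][1] + top[1][1]
weights = {idx: score / norm for idx, score in top}         # renormalized mixing weights
print(weights)  # output = sum(weights[i] * experts[i](x) for i in weights)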

+ 19 - 0
tinychat/examples/mlperf/README

@@ -0,0 +1,19 @@
+Each model should be a clean single file.
+Models are imported from the top-level `models` directory.
+
+Each model should be capable of loading weights from the reference implementation.
+
+We will focus on these 5 models:
+
+# Resnet50-v1.5 (classic) -- 8.2 GOPS/input
+# Retinanet
+# 3D UNET (upconvs)
+# RNNT
+# BERT-large (transformer)
+
+They are used in both the training and inference benchmark:
+https://mlcommons.org/en/training-normal-21/
+https://mlcommons.org/en/inference-edge-30/
+And we will submit to both.
+
+NOTE: we are Edge since we don't have ECC RAM

+ 378 - 0
tinychat/examples/mlperf/dataloader.py

@@ -0,0 +1,378 @@
+import os, random, pickle, functools, itertools
+from typing import List, Tuple
+from pathlib import Path
+import numpy as np
+from PIL import Image
+from tqdm import tqdm
+from tinygrad import dtypes, Tensor
+from tinygrad.helpers import getenv, prod, Context, round_up
+from collections import deque
+from multiprocessing import Queue, Process, shared_memory, connection, Lock, cpu_count, Pool
+
+class MyQueue:
+  def __init__(self, multiple_readers=True, multiple_writers=True):
+    self._reader, self._writer = connection.Pipe(duplex=False)
+    self._rlock = Lock() if multiple_readers else None
+    self._wlock = Lock() if multiple_writers else None
+  def get(self):
+    if self._rlock: self._rlock.acquire()
+    ret = pickle.loads(self._reader.recv_bytes())
+    if self._rlock: self._rlock.release()
+    return ret
+  def put(self, obj):
+    if self._wlock: self._wlock.acquire()
+    self._writer.send_bytes(pickle.dumps(obj))
+    if self._wlock: self._wlock.release()
+
+def shuffled_indices(n, seed=None):
+  rng = random.Random(seed)
+  indices = {}
+  for i in range(n-1, -1, -1):
+    j = rng.randint(0, i)
+    if i not in indices: indices[i] = i
+    if j not in indices: indices[j] = j
+    indices[i], indices[j] = indices[j], indices[i]
+    yield indices[i]
+    del indices[i]
+
+def loader_process(q_in, q_out, X:Tensor, seed):
+  import signal
+  signal.signal(signal.SIGINT, lambda _, __: exit(0))
+
+  from extra.datasets.imagenet import center_crop, preprocess_train
+
+  with Context(DEBUG=0):
+    while (_recv := q_in.get()) is not None:
+      idx, fn, val = _recv
+      if fn is not None:
+        img = Image.open(fn)
+        img = img.convert('RGB') if img.mode != "RGB" else img
+
+        if val:
+          # eval: 76.08%, load in 0m7.366s (0m5.301s with simd)
+          # sudo apt-get install libjpeg-dev
+          # CC="cc -mavx2" pip install -U --force-reinstall pillow-simd
+          img = center_crop(img)
+          img = np.array(img)
+        else:
+          # reseed rng for determinism
+          if seed is not None:
+            np.random.seed(seed * 2 ** 10 + idx)
+            random.seed(seed * 2 ** 10 + idx)
+          img = preprocess_train(img)
+      else:
+        # pad data with training mean
+        img = np.tile(np.array([[[123.68, 116.78, 103.94]]], dtype=np.uint8), (224, 224, 1))
+
+      # broken out
+      #img_tensor = Tensor(img.tobytes(), device='CPU')
+      #storage_tensor = X[idx].contiguous().realize().lazydata.realized
+      #storage_tensor._copyin(img_tensor.numpy())
+
+      # faster
+      X[idx].contiguous().realize().lazydata.realized.as_buffer(force_zero_copy=True)[:] = img.tobytes()
+
+      # ideal
+      #X[idx].assign(img.tobytes())   # NOTE: this is slow!
+      q_out.put(idx)
+    q_out.put(None)
+
+def batch_load_resnet(batch_size=64, val=False, shuffle=True, seed=None, pad_first_batch=False):
+  from extra.datasets.imagenet import get_train_files, get_val_files
+  files = get_val_files() if val else get_train_files()
+  from extra.datasets.imagenet import get_imagenet_categories
+  cir = get_imagenet_categories()
+
+  if pad_first_batch:
+    FIRST_BATCH_PAD = round_up(len(files), batch_size) - len(files)
+  else:
+    FIRST_BATCH_PAD = 0
+  file_count = FIRST_BATCH_PAD + len(files)
+  BATCH_COUNT = min(32, file_count // batch_size)
+
+  def _gen():
+    for _ in range(FIRST_BATCH_PAD): yield -1
+    yield from shuffled_indices(len(files), seed=seed) if shuffle else iter(range(len(files)))
+  gen = iter(_gen())
+
+  def enqueue_batch(num):
+    for idx in range(num*batch_size, (num+1)*batch_size):
+      fidx = next(gen)
+      if fidx != -1:
+        fn = files[fidx]
+        q_in.put((idx, fn, val))
+        Y[idx] = cir[fn.split("/")[-2]]
+      else:
+        # padding
+        q_in.put((idx, None, val))
+        Y[idx] = -1
+
+  shutdown = False
+  class Cookie:
+    def __init__(self, num): self.num = num
+    def __del__(self):
+      if not shutdown:
+        try: enqueue_batch(self.num)
+        except StopIteration: pass
+
+  gotten = [0]*BATCH_COUNT
+  def receive_batch():
+    while 1:
+      num = q_out.get()//batch_size
+      gotten[num] += 1
+      if gotten[num] == batch_size: break
+    gotten[num] = 0
+    return X[num*batch_size:(num+1)*batch_size], Y[num*batch_size:(num+1)*batch_size], Cookie(num)
+
+  #q_in, q_out = MyQueue(multiple_writers=False), MyQueue(multiple_readers=False)
+  q_in, q_out = Queue(), Queue()
+
+  sz = (batch_size*BATCH_COUNT, 224, 224, 3)
+  if os.path.exists("/dev/shm/resnet_X"): os.unlink("/dev/shm/resnet_X")
+  shm = shared_memory.SharedMemory(name="resnet_X", create=True, size=prod(sz))
+  procs = []
+
+  try:
+    # disk:shm is slower
+    #X = Tensor.empty(*sz, dtype=dtypes.uint8, device=f"disk:shm:{shm.name}")
+    X = Tensor.empty(*sz, dtype=dtypes.uint8, device=f"disk:/dev/shm/resnet_X")
+    Y = [None] * (batch_size*BATCH_COUNT)
+
+    for _ in range(cpu_count()):
+      p = Process(target=loader_process, args=(q_in, q_out, X, seed))
+      p.daemon = True
+      p.start()
+      procs.append(p)
+
+    for bn in range(BATCH_COUNT): enqueue_batch(bn)
+
+    # NOTE: this is batch aligned, last ones are ignored unless pad_first_batch is True
+    for _ in range(0, file_count//batch_size): yield receive_batch()
+  finally:
+    shutdown = True
+    # empty queues
+    for _ in procs: q_in.put(None)
+    q_in.close()
+    for _ in procs:
+      while q_out.get() is not None: pass
+    q_out.close()
+    # shutdown processes
+    for p in procs: p.join()
+    shm.close()
+    try:
+      shm.unlink()
+    except FileNotFoundError:
+      # happens with BENCHMARK set
+      pass
+
+@functools.lru_cache(maxsize=128)
+def load_bert_file(fn:str) -> List[dict]:
+  with open(fn, "rb") as f: data = pickle.load(f)
+  return data
+
+def process_batch_bert(data: List[dict]) -> dict[str, Tensor]:
+  return {
+    "input_ids": Tensor(np.concatenate([s["input_ids"] for s in data], axis=0), dtype=dtypes.float32),
+    "input_mask": Tensor(np.concatenate([s["input_mask"] for s in data], axis=0), dtype=dtypes.default_float),
+    "segment_ids": Tensor(np.concatenate([s["segment_ids"] for s in data], axis=0), dtype=dtypes.float32),
+    "masked_lm_positions": Tensor(np.concatenate([s["masked_lm_positions"] for s in data], axis=0), dtype=dtypes.float32),
+    "masked_lm_ids": Tensor(np.concatenate([s["masked_lm_ids"] for s in data], axis=0), dtype=dtypes.float32),
+    "masked_lm_weights": Tensor(np.concatenate([s["masked_lm_weights"] for s in data], axis=0), dtype=dtypes.float32),
+    "next_sentence_labels": Tensor(np.concatenate([s["next_sentence_labels"] for s in data], axis=0), dtype=dtypes.float32),
+  }
+
+def shuffle_parts(file_paths: List[str]) -> List[str]:
+  parts = {}
+  for f in file_paths:
+    part = Path(f).stem.split('_')[0]
+    if part not in parts: parts[part] = []
+    parts[part].append(f)
+  
+  part_ids = list(parts.keys())
+  random.shuffle(part_ids)
+
+  shuffled_files = []
+  for p in part_ids:
+    parts[p].sort(key=lambda x: int(Path(x).stem.split('_')[1]))
+    shuffled_files.extend(parts[p])
+  return shuffled_files
+
+def random_sample(data: List[str]):
+  index = random.randint(0, len(data) - 1)
+  selected_sample = data[index]
+  return selected_sample, index
+
+def load_datasample(file_and_offset:Tuple[str, int]) -> List[dict]:
+  data = load_bert_file(file_and_offset[0])
+  return data[file_and_offset[1]]
+
+# Reference: https://github.com/mlcommons/training/blob/1c8a098ae3e70962a4f7422c0b0bd35ae639e357/language_model/tensorflow/bert/run_pretraining.py, Line 394
+def batch_load_train_bert(BS:int, start_step:int = 0):
+  from extra.datasets.wikipedia import get_wiki_train_files
+  files = shuffle_parts(get_wiki_train_files())
+  dataset = []
+  for f in tqdm(files, desc="Building dataset"):
+    lists = [(f, o) for o in range(int(Path(f).stem.split("_")[3].split(".")[0]))]
+    dataset.extend(lists)
+  
+  dataset = dataset[start_step:]
+  
+  active_set = deque(dataset[:1000])
+  remaining_set = deque(dataset[1000:])
+
+  while dataset:
+    blob = []
+    for _ in range(BS):
+      if active_set:
+        index = random.randint(0, len(active_set) - 1)
+        sample = active_set[index]
+        active_set.remove(sample)
+        blob.append(sample)
+        if remaining_set:
+            active_set.append(remaining_set.popleft())
+    yield process_batch_bert([load_datasample(sample) for sample in blob])
+
+# Reference: https://github.com/mlcommons/training/blob/1c8a098ae3e70962a4f7422c0b0bd35ae639e357/language_model/tensorflow/bert/run_pretraining.py, Line 416
+def batch_load_val_bert(BS:int):
+  from extra.datasets.wikipedia import get_wiki_val_files
+  files = get_wiki_val_files()
+  dataset = list(itertools.chain.from_iterable([load_bert_file(f) for f in files]))
+  idx = 0
+  while True:
+    start_idx = (idx * BS) % len(dataset)
+    end_idx = ((idx + 1) * BS) % len(dataset)
+    if start_idx < end_idx:
+      yield process_batch_bert(dataset[start_idx:end_idx])
+    else:  # wrap around the end to the beginning of the dataset
+      yield process_batch_bert(dataset[start_idx:] + dataset[:end_idx])
+    idx += 1
+
+def load_unet3d_data(preprocessed_dataset_dir, seed, queue_in, queue_out, X:Tensor, Y:Tensor):
+  from extra.datasets.kits19 import rand_balanced_crop, rand_flip, random_brightness_augmentation, gaussian_noise
+
+  while (data := queue_in.get()) is not None:
+    idx, fn, val = data
+    case_name = os.path.basename(fn).split("_x.npy")[0]
+    x, y = np.load(preprocessed_dataset_dir / f"{case_name}_x.npy"), np.load(preprocessed_dataset_dir / f"{case_name}_y.npy")
+
+    if not val:
+      if seed is not None:
+        np.random.seed(seed)
+        random.seed(seed)
+
+      x, y = rand_balanced_crop(x, y)
+      x, y = rand_flip(x, y)
+      x, y = x.astype(np.float32), y.astype(np.uint8)
+      x = random_brightness_augmentation(x)
+      x = gaussian_noise(x)
+
+    X[idx].contiguous().realize().lazydata.realized.as_buffer(force_zero_copy=True)[:] = x.tobytes()
+    Y[idx].contiguous().realize().lazydata.realized.as_buffer(force_zero_copy=True)[:] = y.tobytes()
+
+    queue_out.put(idx)
+  queue_out.put(None)
+
+def batch_load_unet3d(preprocessed_dataset_dir:Path, batch_size:int=6, val:bool=False, shuffle:bool=True, seed=None):
+  assert preprocessed_dataset_dir is not None, "run preprocess_data on kits19"
+
+  files = sorted(list(preprocessed_dataset_dir.glob("*_x.npy")))
+  file_indices = list(range(len(files)))
+  batch_count = min(32, len(files) // batch_size)
+
+  queue_in, queue_out = Queue(), Queue()
+  procs, data_out_count = [], [0] * batch_count
+  shm_name_x, shm_name_y = "unet3d_x", "unet3d_y"
+  sz = (batch_size * batch_count, 1, 128, 128, 128)
+  if os.path.exists(f"/dev/shm/{shm_name_x}"): os.unlink(f"/dev/shm/{shm_name_x}")
+  if os.path.exists(f"/dev/shm/{shm_name_y}"): os.unlink(f"/dev/shm/{shm_name_y}")
+  shm_x = shared_memory.SharedMemory(name=shm_name_x, create=True, size=prod(sz))
+  shm_y = shared_memory.SharedMemory(name=shm_name_y, create=True, size=prod(sz))
+
+  shutdown = False
+  class Cookie:
+    def __init__(self, bc):
+      self.bc = bc
+    def __del__(self):
+      if not shutdown:
+        try: enqueue_batch(self.bc)
+        except StopIteration: pass
+
+  def enqueue_batch(bc):
+    for idx in range(bc * batch_size, (bc+1) * batch_size):
+      fn = files[next(ds_iter)]
+      queue_in.put((idx, fn, val))
+
+  def shuffle_indices(file_indices, seed=None):
+    rng = random.Random(seed)
+    rng.shuffle(file_indices)
+
+  if shuffle: shuffle_indices(file_indices, seed=seed)
+  ds_iter = iter(file_indices)
+
+  try:
+    X = Tensor.empty(*sz, dtype=dtypes.float32, device=f"disk:/dev/shm/{shm_name_x}")
+    Y = Tensor.empty(*sz, dtype=dtypes.uint8, device=f"disk:/dev/shm/{shm_name_y}")
+
+    for _ in range(cpu_count()):
+      proc = Process(target=load_unet3d_data, args=(preprocessed_dataset_dir, seed, queue_in, queue_out, X, Y))
+      proc.daemon = True
+      proc.start()
+      
+      procs.append(proc)
+
+    for bc in range(batch_count):
+      enqueue_batch(bc)
+
+    for _ in range(len(files) // batch_size):
+      while True:
+        bc = queue_out.get() // batch_size
+        data_out_count[bc] += 1
+        if data_out_count[bc] == batch_size: break
+
+      data_out_count[bc] = 0
+      yield X[bc * batch_size:(bc + 1) * batch_size], Y[bc * batch_size:(bc + 1) * batch_size], Cookie(bc)
+  finally:
+    shutdown = True
+
+    for _ in procs: queue_in.put(None)
+    queue_in.close()
+
+    for _ in procs:
+      while queue_out.get() is not None: pass
+    queue_out.close()
+
+    # shutdown processes
+    for proc in procs: proc.join()
+
+    shm_x.close()
+    shm_y.close()
+    try:
+      shm_x.unlink()
+      shm_y.unlink()
+    except FileNotFoundError:
+      # happens with BENCHMARK set
+      pass
+
+if __name__ == "__main__":
+  def load_unet3d(val):
+    assert not val, "validation set is not supported due to differently sized inputs"
+
+    from extra.datasets.kits19 import get_train_files, get_val_files, preprocess_dataset, BASEDIR
+    preprocessed_dataset_dir = (BASEDIR / ".." / "preprocessed" / ("val" if val else "train"))
+    files = get_val_files() if val else get_train_files()
+
+    if not preprocessed_dataset_dir.exists(): preprocess_dataset(files, preprocessed_dataset_dir, val)
+    with tqdm(total=len(files)) as pbar:
+      for x, _, _ in batch_load_unet3d(preprocessed_dataset_dir, val=val):
+        pbar.update(x.shape[0])
+
+  def load_resnet(val):
+    from extra.datasets.imagenet import get_train_files, get_val_files
+    files = get_val_files() if val else get_train_files()
+    with tqdm(total=len(files)) as pbar:
+      for x,y,c in batch_load_resnet(val=val):
+        pbar.update(x.shape[0])
+
+  load_fn_name = f"load_{getenv('MODEL', 'resnet')}"
+  if load_fn_name in globals():
+    globals()[load_fn_name](getenv("VAL", 1))
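
The BERT train loader above shuffles at the granularity of dataset parts: part ids are permuted, but the samples inside each part keep their original index order, and batches are then drawn at random from a sliding 1000-sample window refilled from the remainder. A minimal sketch of the part-wise shuffle, assuming file stems of the form "<part>_<index>" (the names below are made up for illustration):

import random
from pathlib import Path

def shuffle_parts_sketch(file_paths, seed=0):
  # group files by the part id encoded before the first underscore of the stem
  parts = {}
  for f in file_paths:
    parts.setdefault(Path(f).stem.split("_")[0], []).append(f)
  # permute the parts, but keep samples inside each part in index order
  part_ids = list(parts)
  random.Random(seed).shuffle(part_ids)
  out = []
  for p in part_ids:
    out.extend(sorted(parts[p], key=lambda x: int(Path(x).stem.split("_")[1])))
  return out

print(shuffle_parts_sketch(["0_1.bin", "0_0.bin", "1_0.bin", "1_1.bin"]))
# parts come out in a shuffled order while intra-part order is preserved, e.g. ['1_0.bin', '1_1.bin', '0_0.bin', '0_1.bin']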

+ 240 - 0
tinychat/examples/mlperf/helpers.py

@@ -0,0 +1,240 @@
+from collections import OrderedDict
+import unicodedata
+import numpy as np
+from tinygrad.nn import state
+from tinygrad.tensor import Tensor, dtypes
+from tinygrad.helpers import getenv
+
+#
+# checkpointing utils
+#
+
+def invert_dict(d): return {v: k for k, v in reversed(d.items())}
+def dedup_dict(d): return invert_dict(invert_dict(d))
+# store each tensor into the first key it appears in
+def get_training_state(model, optimizer, scheduler):
+  # hack: let get_state_dict walk the tree starting with model, so that the checkpoint keys are
+  # readable and can be loaded as a model for eval
+  train_state = {'model': model, 'optimizer': optimizer, 'scheduler': scheduler}
+  return dedup_dict(state.get_state_dict(train_state))
+def load_training_state(model, optimizer, scheduler, state_dict):
+  # use fresh model to restore duplicate keys
+  train_state = {'model': model, 'optimizer': optimizer, 'scheduler': scheduler}
+  big_dict = state.get_state_dict(train_state)
+  # hack: put back the dupes
+  dupe_names = {}
+  for k, v in big_dict.items():
+    if v not in dupe_names:
+      dupe_names[v] = k
+      assert k in state_dict
+    state_dict[k] = state_dict[dupe_names[v]]
+  # scheduler contains optimizer and all params, load each weight only once
+  scheduler_state = {'scheduler': scheduler}
+  state.load_state_dict(scheduler_state, state_dict)
+
+def gaussian_kernel(n, std):
+  from scipy import signal
+  gaussian_1d = signal.windows.gaussian(n, std)
+  gaussian_2d = np.outer(gaussian_1d, gaussian_1d)
+  gaussian_3d = np.outer(gaussian_2d, gaussian_1d)
+  gaussian_3d = gaussian_3d.reshape(n, n, n)
+  gaussian_3d = np.cbrt(gaussian_3d)
+  gaussian_3d /= gaussian_3d.max()
+  return gaussian_3d
+
+def prepare_arrays(image, roi_shape=(128, 128, 128)):
+  assert len(roi_shape) == 3 and any(roi_shape)
+  image_shape = list(image.shape[2:])
+  result = np.zeros((1, 3, *image_shape), dtype=image.dtype)
+  norm_map = np.zeros_like(result)
+  norm_patch = gaussian_kernel(roi_shape[0], 0.125 * roi_shape[0]).astype(norm_map.dtype)
+  return result, norm_map, norm_patch
+
+def get_slice(image, roi_shape=(128, 128, 128), overlap_factor=0.5):
+  assert len(roi_shape) == 3 and any(roi_shape)
+  assert 0 < overlap_factor < 1
+  image_shape, dim = list(image.shape[2:]), len(image.shape[2:])
+  strides = [int(roi_shape[i] * (1 - overlap_factor)) for i in range(dim)]
+  size = [(image_shape[i] - roi_shape[i]) // strides[i] + 1 for i in range(dim)]
+  for i in range(0, strides[0] * size[0], strides[0]):
+    for j in range(0, strides[1] * size[1], strides[1]):
+      for k in range(0, strides[2] * size[2], strides[2]):
+        yield i, j, k
+
+def _get_best_indices(logits, n_best_size):
+  index_and_score = sorted(enumerate(logits), key=lambda x: x[1], reverse=True)
+  return list(map(lambda x: x[0], index_and_score))[:n_best_size]
+
+def _is_punctuation(char):
+  if (cp := ord(char)) in range(33, 48) or cp in range(58, 65) or cp in range(91, 97) or cp in range(123, 127):
+    return True
+  return unicodedata.category(char).startswith("P")
+
+def _is_whitespace(char):
+  if char == " " or char == "\t" or char == "\n" or char == "\r":
+    return True
+  return unicodedata.category(char) == "Zs"
+
+def _is_control(char):
+  if char == "\t" or char == "\n" or char == "\r":
+    return False
+  return unicodedata.category(char).startswith("C")
+
+def _run_split_on_punc(text):
+  if text in ("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]"):
+    return [text]
+  start_new_word = True
+  output = []
+  for i in range(len(text)):
+    if _is_punctuation(char := text[i]):
+      output.append([char])
+      start_new_word = True
+    else:
+      if start_new_word:
+        output.append([])
+      start_new_word = False
+      output[-1].append(char)
+  return ["".join(x) for x in output]
+
+def _run_strip_accents(text):
+  output = []
+  for char in unicodedata.normalize("NFD", text):
+    if unicodedata.category(char) != "Mn":
+      output.append(char)
+  return "".join(output)
+
+def _clean_text(text):
+  output = []
+  for char in text:
+    if not ((cp := ord(char)) == 0 or cp == 0xfffd or _is_control(char)):
+      output.append(" " if _is_whitespace(char) else char)
+  return "".join(output)
+
+def _get_final_text(pred_text, orig_text):
+  def _strip_spaces(text):
+    ns_text = ""
+    ns_to_s_map = OrderedDict()
+    for i, c in enumerate(text):
+      if c == " ":
+        continue
+      ns_to_s_map[len(ns_text)] = i
+      ns_text += c
+    return ns_text, ns_to_s_map
+
+  orig_tokens = _clean_text(orig_text).strip().split()
+  split_tokens = []
+  for token in orig_tokens:
+    if token not in ("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]"):
+      token = token.lower()
+      token = _run_strip_accents(token)
+    split_tokens.extend(_run_split_on_punc(token))
+
+  tok_text = " ".join(" ".join(split_tokens).strip().split())
+  start_position = tok_text.find(pred_text)
+  if start_position == -1:
+    return orig_text
+  end_position = start_position + len(pred_text) - 1
+
+  orig_ns_text, orig_ns_to_s_map = _strip_spaces(orig_text)
+  tok_ns_text, tok_ns_to_s_map = _strip_spaces(tok_text)
+  if len(orig_ns_text) != len(tok_ns_text):
+    return orig_text
+  tok_s_to_ns_map = {v: k for k, v in tok_ns_to_s_map.items()}
+
+  orig_start_position = None
+  if start_position in tok_s_to_ns_map:
+    if (ns_start_position := tok_s_to_ns_map[start_position]) in orig_ns_to_s_map:
+      orig_start_position = orig_ns_to_s_map[ns_start_position]
+  if orig_start_position is None:
+    return orig_text
+
+  orig_end_position = None
+  if end_position in tok_s_to_ns_map:
+    if (ns_end_position := tok_s_to_ns_map[end_position]) in orig_ns_to_s_map:
+      orig_end_position = orig_ns_to_s_map[ns_end_position]
+  if orig_end_position is None:
+    return orig_text
+
+  output_text = orig_text[orig_start_position:(orig_end_position + 1)]
+  return output_text
+
+def get_bert_qa_prediction(features, example, start_end_logits):
+  prelim_predictions = []
+  for i, feature in enumerate(features):
+    for start_index in _get_best_indices(start_end_logits[i][0], 20):
+      for end_index in _get_best_indices(start_end_logits[i][1], 20):
+        if start_index >= len(feature["tokens"]) or end_index >= len(feature["tokens"]):
+          continue
+        if start_index not in feature["token_to_orig_map"] or end_index not in feature["token_to_orig_map"]:
+          continue
+        if not feature["token_is_max_context"].get(start_index, False):
+          continue
+        if end_index < start_index or end_index - start_index + 1 > 30:
+          continue
+
+        prelim_predictions.append({
+          "feature_index": i,
+          "start_index": start_index,
+          "end_index": end_index,
+          "start_logit": start_end_logits[i][0, start_index],
+          "end_logit": start_end_logits[i][1, end_index]
+        })
+  predictions = sorted(prelim_predictions, key=lambda x: (x["start_logit"] + x["end_logit"]), reverse=True)
+
+  if len(predictions) > 0:
+    feature = features[predictions[0]["feature_index"]]
+    tok_tokens = feature["tokens"][predictions[0]["start_index"]:(predictions[0]["end_index"] + 1)]
+    orig_doc_start = feature["token_to_orig_map"][predictions[0]["start_index"]]
+    orig_doc_end = feature["token_to_orig_map"][predictions[0]["end_index"]]
+    orig_tokens = example["context"][orig_doc_start:(orig_doc_end + 1)]
+    tok_text = " ".join(tok_tokens).replace(" ##", "").replace("##", "")
+    tok_text = " ".join(tok_text.strip().split())
+    orig_text = " ".join(orig_tokens)
+    return _get_final_text(tok_text, orig_text)
+  return "empty"
+
+def get_mlperf_bert_config():
+  """Config is BERT-large"""
+  return {
+    "attention_probs_dropout_prob": 0.1,
+    "hidden_dropout_prob": 0.1,
+    "hidden_size": 1024,
+    "intermediate_size": 4096,
+    "max_position_embeddings": 512,
+    "num_attention_heads": 16,
+    "num_hidden_layers": 24,
+    "type_vocab_size": 2,
+    "vocab_size": 30522
+  }
+
+def get_mlperf_bert_model(checkpoint_path:str=""):
+  from extra.models import bert
+  from examples.mlperf.initializers import LinearBert, EmbeddingBert, LayerNormBert
+
+  bert.Linear = LinearBert
+  bert.Embedding = EmbeddingBert 
+  bert.LayerNorm = LayerNormBert
+
+  from extra.models.bert import BertForPretraining
+  config = get_mlperf_bert_config()
+  if getenv("DISABLE_DROPOUT", 0):
+    config["hidden_dropout_prob"] = config["attention_probs_dropout_prob"] = 0.0
+  model = BertForPretraining(**config)
+  if checkpoint_path: model.load_from_pretrained(checkpoint_path)
+  return model
+
+def get_data_bert(GPUS:list[str], it):
+  data: dict[str, Tensor] = next(it)
+  for key in data.keys(): data[key].shard_(GPUS, axis=0)
+  return data
+
+def get_fake_data_bert(GPUS:list[str], BS:int):
+  return {
+    "input_ids": Tensor.empty((BS, 512), dtype=dtypes.float32).contiguous().shard_(GPUS, axis=0),
+    "input_mask": Tensor.empty((BS, 512), dtype=dtypes.default_float).contiguous().shard_(GPUS, axis=0),
+    "segment_ids": Tensor.empty((BS, 512), dtype=dtypes.float32).contiguous().shard_(GPUS, axis=0),
+    "masked_lm_positions": Tensor.empty((BS, 76), dtype=dtypes.float32).contiguous().shard_(GPUS, axis=0),
+    "masked_lm_ids": Tensor.empty((BS, 76), dtype=dtypes.float32).contiguous().shard_(GPUS, axis=0),
+    "masked_lm_weights": Tensor.empty((BS, 76), dtype=dtypes.float32).contiguous().shard_(GPUS, axis=0),
+    "next_sentence_labels": Tensor.empty((BS, 1), dtype=dtypes.float32).contiguous().shard_(GPUS, axis=0),
+  }
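
invert_dict/dedup_dict above rely on reversed insertion order so that, when the same tensor object is reachable under several state-dict keys, only the first key survives the round trip. A plain-dict illustration, with strings standing in for shared tensors and made-up key names:

def invert_dict(d): return {v: k for k, v in reversed(d.items())}
def dedup_dict(d): return invert_dict(invert_dict(d))

d = {"model.tok_emb.weight": "T1", "model.lm_head.weight": "T1", "model.bias": "T2"}
print(dedup_dict(d))
# {'model.tok_emb.weight': 'T1', 'model.bias': 'T2'} -- the tied tensor keeps only its first key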

+ 68 - 0
tinychat/examples/mlperf/initializers.py

@@ -0,0 +1,68 @@
+import math
+from typing import Union, Tuple
+
+from tinygrad import Tensor, nn, dtypes
+from tinygrad.helpers import prod, argfix
+
+# rejection sampling truncated randn
+def rand_truncn(*shape, dtype=None, truncstds=2, **kwargs) -> Tensor:
+  CNT=8
+  x = Tensor.randn(*(*shape, CNT), dtype=dtype, **kwargs)
+  ctr = Tensor.arange(CNT).reshape((1,) * len(x.shape[:-1]) + (CNT,)).expand(x.shape)
+  take = (x.abs() <= truncstds).where(ctr, CNT).min(axis=-1, keepdim=True)  # set to 0 if no good samples
+  return (ctr == take).where(x, 0).sum(axis=-1)
+
+# https://github.com/keras-team/keras/blob/v2.15.0/keras/initializers/initializers.py#L1026-L1065
+def he_normal(*shape, a: float = 0.00, **kwargs) -> Tensor:
+  std = math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(prod(argfix(*shape)[1:])) / 0.87962566103423978
+  return std * rand_truncn(*shape, **kwargs)
+
+class Conv2dHeNormal(nn.Conv2d):
+  def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
+    super().__init__(in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
+    self.in_channels, self.out_channels = in_channels, out_channels  # for testing
+    self.weight = he_normal(out_channels, in_channels//groups, *self.kernel_size, a=0.0, dtype=dtypes.float32)
+    if bias: self.bias = self.bias.cast(dtypes.float32)
+  def __call__(self, x: Tensor):
+    return x.conv2d(self.weight.cast(dtypes.default_float), self.bias.cast(dtypes.default_float) if self.bias is not None else None,
+                    padding=self.padding, stride=self.stride, dilation=self.dilation, groups=self.groups)
+
+class Linear(nn.Linear):
+  def __init__(self, in_features, out_features, bias=True):
+    super().__init__(in_features, out_features, bias=bias)
+    self.weight = Tensor.normal((out_features, in_features), mean=0.0, std=0.01, dtype=dtypes.float32)
+    if bias: self.bias = Tensor.zeros(out_features, dtype=dtypes.float32)
+  def __call__(self, x:Tensor):
+    return x.linear(self.weight.cast(dtypes.default_float).transpose(), self.bias.cast(dtypes.default_float) if self.bias is not None else None)
+
+class LinearBert(nn.Linear):
+  def __init__(self, in_features, out_features, bias=True, std=0.02):
+    self.weight = std * rand_truncn(out_features, in_features, dtype=dtypes.float32)
+    self.bias = Tensor.zeros(out_features, dtype=dtypes.float32) if bias else None
+  
+  def __call__(self, x:Tensor):
+    return x.cast(dtypes.default_float).linear(self.weight.cast(dtypes.default_float).transpose(), self.bias.cast(dtypes.default_float) if self.bias is not None else None)
+
+class EmbeddingBert(nn.Embedding):
+  def __init__(self, vocab_size:int, embed_size:int, std=0.02):
+    self.vocab_sz, self.embed_sz = vocab_size, embed_size
+    self.weight = std * rand_truncn(vocab_size, embed_size, dtype=dtypes.float32)
+
+  def __call__(self, idx:Tensor) -> Tensor:
+    if idx.numel() == 0: return Tensor.empty(idx.shape+(self.embed_sz,), dtype=self.weight.dtype, device=self.weight.device)
+    arange_shp, weight_shp, big_shp = (1, 1, self.vocab_sz, 1), (1, 1, self.vocab_sz, self.embed_sz), idx.shape+(self.vocab_sz, self.embed_sz,)
+    if not hasattr(self, 'arange'): self.arange = Tensor.arange(self.vocab_sz, requires_grad=False, device=self.weight.device).reshape(arange_shp)
+    arange, idx, vals = self.arange.expand(big_shp), idx.reshape(idx.shape+(1, 1,)).expand(big_shp), self.weight.cast(dtypes.default_float).reshape(weight_shp).expand(big_shp)
+    return (arange == idx).mul(vals).sum(2, acc_dtype=vals.dtype)
+
+class LayerNormBert:
+  def __init__(self, normalized_shape:Union[int, Tuple[int, ...]], eps:float=1e-12, elementwise_affine:bool=True):
+    self.normalized_shape = (normalized_shape,) if isinstance(normalized_shape, int) else tuple(normalized_shape)
+    self.axis, self.eps, self.elementwise_affine = tuple(-1-i for i in range(len(self.normalized_shape))), eps, elementwise_affine
+    self.weight, self.bias = (Tensor.ones(*self.normalized_shape, dtype=dtypes.float32), Tensor.zeros(*self.normalized_shape, dtype=dtypes.float32)) if elementwise_affine else (None, None)
+
+  def __call__(self, x:Tensor):
+    assert self.normalized_shape == x.shape[-len(self.normalized_shape):], f"last dimensions of {x.shape} must match {self.normalized_shape}"
+    xn = x.cast(dtypes.float32).layernorm(eps=self.eps, axis=self.axis).cast(x.dtype)
+    if not self.elementwise_affine: return xn
+    return (xn * self.weight.cast(dtypes.default_float) + self.bias.cast(dtypes.default_float))
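
rand_truncn above does rejection sampling: it draws CNT candidate normal samples per output element and keeps the first one within +/- truncstds standard deviations, falling back to 0 if none qualify. An equivalent numpy sketch of the same idea (illustration only, not the tinygrad implementation):

import numpy as np

def rand_truncn_np(*shape, truncstds=2, cnt=8, seed=0):
  rng = np.random.default_rng(seed)
  x = rng.standard_normal((*shape, cnt))          # cnt candidates per output element
  ok = np.abs(x) <= truncstds                     # candidates inside the truncation band
  first = ok.argmax(axis=-1)                      # index of the first accepted candidate (0 if none)
  picked = np.take_along_axis(x, first[..., None], axis=-1)[..., 0]
  return np.where(ok.any(axis=-1), picked, 0.0)   # fall back to 0 when every candidate was rejected

w = rand_truncn_np(4, 3)
print(w.shape, bool(np.abs(w).max() <= 2))  # (4, 3) True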

+ 6 - 0
tinychat/examples/mlperf/losses.py

@@ -0,0 +1,6 @@
+from examples.mlperf.metrics import dice_score
+
+def dice_ce_loss(pred, tgt):
+  ce = pred.permute(0, 2, 3, 4, 1).sparse_categorical_crossentropy(tgt.squeeze(1))
+  dice = (1.0 - dice_score(pred, tgt, argmax=False, to_one_hot_x=False)).mean()
+  return (dice + ce) / 2
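
dice_ce_loss averages a soft Dice term (computed on softmax probabilities via dice_score with argmax=False) with sparse cross-entropy. For intuition, on hard binary masks the Dice coefficient reduces to 2*|A∩B| / (|A| + |B|); a toy numpy check of that formula:

import numpy as np

pred = np.array([1, 1, 0, 0])   # predicted foreground voxels
tgt  = np.array([1, 0, 0, 0])   # ground-truth foreground voxels
smooth = 1e-6
inter = (pred * tgt).sum()
dice = (2 * inter + smooth) / (pred.sum() + tgt.sum() + smooth)
print(round(float(dice), 4))    # 0.6667 -- one overlapping voxel against 2 + 1 foreground voxels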

+ 22 - 0
tinychat/examples/mlperf/lr_schedulers.py

@@ -0,0 +1,22 @@
+from tinygrad import Tensor, dtypes
+from tinygrad.nn.optim import Optimizer
+
+from extra.lr_scheduler import LR_Scheduler
+
+# https://github.com/mlcommons/training/blob/e237206991d10449d9675d95606459a3cb6c21ad/image_classification/tensorflow2/lars_util.py
+class PolynomialDecayWithWarmup(LR_Scheduler):
+  def __init__(self, optimizer: Optimizer, initial_lr, end_lr, train_steps, warmup, power=2):
+    super().__init__(optimizer)
+    self.epoch_counter = self.epoch_counter.cast(dtypes.float32)
+    assert train_steps > 0 and warmup > 0
+    self.warmup = min(warmup, train_steps)
+    self.initial_lr, self.end_lr, self.epochs, self.power = initial_lr, end_lr, train_steps, power
+
+    # set lr for first warmup step
+    self.optimizer.lr.assign(self.get_lr()).realize()
+
+  def get_lr(self):
+    # LR is 0 on the first step, matching the reference.
+    warmup_lr = (self.epoch_counter * (1.0 / self.warmup)) * self.initial_lr
+    x = (1 - (self.epoch_counter - self.warmup) / (self.epochs - self.warmup + 1))
+    return (self.epoch_counter <= self.warmup).where(warmup_lr, (self.initial_lr - self.end_lr) * x ** self.power + self.end_lr).cast(self.optimizer.lr.dtype)
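
PolynomialDecayWithWarmup ramps the LR linearly from 0 to initial_lr over the warmup steps, then decays it polynomially to end_lr. A scalar re-implementation for sanity-checking values (this assumes the LR_Scheduler base class advances epoch_counter by one per optimizer step):

def poly_warmup_lr(step, initial_lr, end_lr, train_steps, warmup, power=2):
  # linear warmup from 0, then polynomial decay toward end_lr
  if step <= warmup:
    return step / warmup * initial_lr
  x = 1 - (step - warmup) / (train_steps - warmup + 1)
  return (initial_lr - end_lr) * x ** power + end_lr

for s in (0, 5, 10, 500, 1000):
  print(s, round(poly_warmup_lr(s, initial_lr=8.0, end_lr=1e-4, train_steps=1000, warmup=10), 4))
# 0 -> 0.0 (matching the "LR is 0 on the first step" comment), 10 -> 8.0, then a smooth decay toward 1e-4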

+ 61 - 0
tinychat/examples/mlperf/metrics.py

@@ -0,0 +1,61 @@
+import re
+import string
+from collections import Counter
+
+def levenshtein(a, b):
+  n, m = len(a), len(b)
+  if n > m:
+    a, b, n, m = b, a, m, n
+
+  current = list(range(n + 1))
+  for i in range(1, m + 1):
+    previous, current = current, [i] + [0] * n
+    for j in range(1, n + 1):
+      add, delete = previous[j] + 1, current[j - 1] + 1
+      change = previous[j - 1]
+      if a[j - 1] != b[i - 1]:
+        change = change + 1
+      current[j] = min(add, delete, change)
+
+  return current[n]
+
+def word_error_rate(x, y):
+  scores = words = 0
+  for h, r in zip(x, y):
+    h_list = h.split()
+    r_list = r.split()
+    words += len(r_list)
+    scores += levenshtein(h_list, r_list)
+  return float(scores) / words, float(scores), words
+
+def one_hot(x):
+  return x.one_hot(3).squeeze(1).permute(0, 4, 1, 2, 3)
+
+def dice_score(prediction, target, channel_axis=1, smooth_nr=1e-6, smooth_dr=1e-6, argmax=True, to_one_hot_x=True):
+  channel_axis, reduce_axis = 1, tuple(range(2, len(prediction.shape)))
+  if argmax: prediction = prediction.argmax(axis=channel_axis)
+  else: prediction = prediction.softmax(axis=channel_axis)
+  if to_one_hot_x: prediction = one_hot(prediction)
+  target = one_hot(target)
+  prediction, target = prediction[:, 1:], target[:, 1:]
+  assert prediction.shape == target.shape, f"prediction ({prediction.shape}) and target ({target.shape}) shapes do not match"
+  intersection = (prediction * target).sum(axis=reduce_axis)
+  target_sum = target.sum(axis=reduce_axis)
+  prediction_sum = prediction.sum(axis=reduce_axis)
+  result = (2.0 * intersection + smooth_nr) / (target_sum + prediction_sum + smooth_dr)
+  return result
+
+def normalize_string(s):
+  s = "".join(c for c in s.lower() if c not in string.punctuation)
+  s = re.sub(r'\b(a|an|the)\b', ' ', s)
+  return " ".join(s.split())
+
+def f1_score(x, y):
+  xt = normalize_string(x).split()
+  yt = normalize_string(y).split()
+  ct = Counter(xt) & Counter(yt)
+  if (ns := sum(ct.values())) == 0:
+    return 0.0
+  p = ns / len(xt)
+  r = ns / len(yt)
+  return 2 * p * r / (p + r)
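
A quick usage check of the two text metrics defined above (the import path assumes the repo root is on PYTHONPATH, as in the other examples.mlperf modules):

from examples.mlperf.metrics import word_error_rate, f1_score

wer, errors, words = word_error_rate(["the cat sat"], ["the cat sat on the mat"])
print(wer, errors, words)  # 0.5 3.0 6 -- three word-level edits over six reference words
print(round(f1_score("the cat sat", "the cat sat on the mat"), 3))  # 0.667 after articles are stripped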

+ 252 - 0
tinychat/examples/mlperf/model_eval.py

@@ -0,0 +1,252 @@
+import time
+start = time.perf_counter()
+from pathlib import Path
+import numpy as np
+from tinygrad import Tensor, Device, dtypes, GlobalCounters, TinyJit
+from tinygrad.nn.state import get_parameters, load_state_dict, safe_load
+from tinygrad.helpers import getenv
+def tlog(x): print(f"{x:25s}  @ {time.perf_counter()-start:5.2f}s")
+
+def eval_resnet():
+  Tensor.no_grad = True
+  # Resnet50-v1.5
+  from extra.models.resnet import ResNet50
+  tlog("imports")
+  GPUS = [f'{Device.DEFAULT}:{i}' for i in range(getenv("GPUS", 6))]
+  for x in GPUS: Device[x]
+  tlog("got devices")    # NOTE: this is faster with rocm-smi running
+
+  class ResnetRunner:
+    def __init__(self, device=None):
+      self.mdl = ResNet50()
+      for x in get_parameters(self.mdl) if device else []: x.to_(device)
+      if (fn:=getenv("RESNET_MODEL", "")): load_state_dict(self.mdl, safe_load(fn))
+      else: self.mdl.load_from_pretrained()
+      self.input_mean = Tensor([0.485, 0.456, 0.406], device=device).reshape(1, -1, 1, 1)
+      self.input_std = Tensor([0.229, 0.224, 0.225], device=device).reshape(1, -1, 1, 1)
+    def __call__(self, x:Tensor) -> Tensor:
+      x = x.permute([0,3,1,2]).cast(dtypes.float32) / 255.0
+      x -= self.input_mean
+      x /= self.input_std
+      return self.mdl(x).log_softmax().argmax(axis=1).realize()
+
+  mdl = TinyJit(ResnetRunner(GPUS))
+  tlog("loaded models")
+
+  # evaluation on the mlperf classes of the validation set from imagenet
+  from examples.mlperf.dataloader import batch_load_resnet
+  iterator = batch_load_resnet(getenv("BS", 128*6), val=getenv("VAL", 1), shuffle=False, pad_first_batch=True)
+  def data_get():
+    x,y,cookie = next(iterator)
+    return x.shard(GPUS, axis=0).realize(), y, cookie
+  n,d = 0,0
+  proc = data_get()
+  tlog("loaded initial data")
+  st = time.perf_counter()
+  while proc is not None:
+    GlobalCounters.reset()
+    proc = (mdl(proc[0]), proc[1], proc[2])  # this frees the images
+    run = time.perf_counter()
+    # load the next data here
+    try: next_proc = data_get()
+    except StopIteration: next_proc = None
+    nd = time.perf_counter()
+    y = np.array(proc[1])
+    proc = (proc[0].numpy() == y) & (y != -1)  # this realizes the models and frees the cookies
+    n += proc.sum()
+    d += (y != -1).sum()
+    et = time.perf_counter()
+    tlog(f"****** {n:5d}/{d:5d}  {n*100.0/d:.2f}% -- {(run-st)*1000:7.2f} ms to enqueue, {(et-run)*1000:7.2f} ms to realize ({(nd-run)*1000:7.2f} ms fetching). {(len(proc))/(et-st):8.2f} examples/sec. {GlobalCounters.global_ops*1e-12/(et-st):5.2f} TFLOPS")
+    st = et
+    proc, next_proc = next_proc, None
+  tlog("done")
+
+def eval_unet3d():
+  # UNet3D
+  from extra.models.unet3d import UNet3D
+  from extra.datasets.kits19 import iterate, sliding_window_inference, get_val_files
+  from examples.mlperf.metrics import dice_score
+  mdl = UNet3D()
+  mdl.load_from_pretrained()
+  s = 0
+  st = time.perf_counter()
+  for i, (image, label) in enumerate(iterate(get_val_files()), start=1):
+    mt = time.perf_counter()
+    pred, label = sliding_window_inference(mdl, image, label)
+    et = time.perf_counter()
+    print(f"{(mt-st)*1000:.2f} ms loading data, {(et-mt)*1000:.2f} ms to run model")
+    s += dice_score(Tensor(pred), Tensor(label)).mean().item()
+    print(f"****** {s:.2f}/{i}  {s/i:.5f} Mean DICE score")
+    st = time.perf_counter()
+
+def eval_retinanet():
+  # RetinaNet with ResNeXt50_32X4D
+  from extra.models.resnet import ResNeXt50_32X4D
+  from extra.models.retinanet import RetinaNet
+  mdl = RetinaNet(ResNeXt50_32X4D())
+  mdl.load_from_pretrained()
+
+  input_mean = Tensor([0.485, 0.456, 0.406]).reshape(1, -1, 1, 1)
+  input_std = Tensor([0.229, 0.224, 0.225]).reshape(1, -1, 1, 1)
+  def input_fixup(x):
+    x = x.permute([0,3,1,2]) / 255.0
+    x -= input_mean
+    x /= input_std
+    return x
+
+  from extra.datasets.openimages import openimages, iterate
+  from pycocotools.coco import COCO
+  from pycocotools.cocoeval import COCOeval
+  from contextlib import redirect_stdout
+  coco = COCO(openimages('validation'))
+  coco_eval = COCOeval(coco, iouType="bbox")
+  coco_evalimgs, evaluated_imgs, ncats, narea = [], [], len(coco_eval.params.catIds), len(coco_eval.params.areaRng)
+
+  from tinygrad.engine.jit import TinyJit
+  mdlrun = TinyJit(lambda x: mdl(input_fixup(x)).realize())
+
+  n, bs = 0, 8
+  st = time.perf_counter()
+  for x, targets in iterate(coco, bs):
+    dat = Tensor(x.astype(np.float32))
+    mt = time.perf_counter()
+    if dat.shape[0] == bs:
+      outs = mdlrun(dat).numpy()
+    else:
+      mdlrun.jit_cache = None
+      outs =  mdl(input_fixup(dat)).numpy()
+    et = time.perf_counter()
+    predictions = mdl.postprocess_detections(outs, input_size=dat.shape[1:3], orig_image_sizes=[t["image_size"] for t in targets])
+    ext = time.perf_counter()
+    n += len(targets)
+    print(f"[{n}/{len(coco.imgs)}] == {(mt-st)*1000:.2f} ms loading data, {(et-mt)*1000:.2f} ms to run model, {(ext-et)*1000:.2f} ms for postprocessing")
+    img_ids = [t["image_id"] for t in targets]
+    coco_results  = [{"image_id": targets[i]["image_id"], "category_id": label, "bbox": box.tolist(), "score": score}
+      for i, prediction in enumerate(predictions) for box, score, label in zip(*prediction.values())]
+    with redirect_stdout(None):
+      coco_eval.cocoDt = coco.loadRes(coco_results)
+      coco_eval.params.imgIds = img_ids
+      coco_eval.evaluate()
+    evaluated_imgs.extend(img_ids)
+    coco_evalimgs.append(np.array(coco_eval.evalImgs).reshape(ncats, narea, len(img_ids)))
+    st = time.perf_counter()
+
+  coco_eval.params.imgIds = evaluated_imgs
+  coco_eval._paramsEval.imgIds = evaluated_imgs
+  coco_eval.evalImgs = list(np.concatenate(coco_evalimgs, -1).flatten())
+  coco_eval.accumulate()
+  coco_eval.summarize()
+
+def eval_rnnt():
+  # RNN-T
+  from extra.models.rnnt import RNNT
+  mdl = RNNT()
+  mdl.load_from_pretrained()
+
+  from extra.datasets.librispeech import iterate
+  from examples.mlperf.metrics import word_error_rate
+
+  LABELS = [" ", "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z", "'"]
+
+  c = 0
+  scores = 0
+  words = 0
+  st = time.perf_counter()
+  for X, Y in iterate():
+    mt = time.perf_counter()
+    tt = mdl.decode(Tensor(X[0]), Tensor([X[1]]))
+    et = time.perf_counter()
+    print(f"{(mt-st)*1000:.2f} ms loading data, {(et-mt)*1000:.2f} ms to run model")
+    for n, t in enumerate(tt):
+      tnp = np.array(t)
+      _, scores_, words_ = word_error_rate(["".join([LABELS[int(tnp[i])] for i in range(tnp.shape[0])])], [Y[n]])
+      scores += scores_
+      words += words_
+    c += len(tt)
+    print(f"WER: {scores/words}, {words} words, raw scores: {scores}, c: {c}")
+    st = time.perf_counter()
+
+def eval_bert():
+  # Bert-QA
+  from extra.models.bert import BertForQuestionAnswering
+  mdl = BertForQuestionAnswering()
+  mdl.load_from_pretrained()
+
+  @TinyJit
+  def run(input_ids, input_mask, segment_ids):
+    return mdl(input_ids, input_mask, segment_ids).realize()
+
+  from extra.datasets.squad import iterate
+  from examples.mlperf.helpers import get_bert_qa_prediction
+  from examples.mlperf.metrics import f1_score
+  from transformers import BertTokenizer
+
+  tokenizer = BertTokenizer(str(Path(__file__).parents[2] / "weights/bert_vocab.txt"))
+
+  c = 0
+  f1 = 0.0
+  st = time.perf_counter()
+  for X, Y in iterate(tokenizer):
+    mt = time.perf_counter()
+    outs = []
+    for x in X:
+      outs.append(run(Tensor(x["input_ids"]), Tensor(x["input_mask"]), Tensor(x["segment_ids"])).numpy())
+    et = time.perf_counter()
+    print(f"{(mt-st)*1000:.2f} ms loading data, {(et-mt)*1000:.2f} ms to run model over {len(X)} features")
+
+    pred = get_bert_qa_prediction(X, Y, outs)
+    print(f"pred: {pred}\nans: {Y['answers']}")
+    f1 += max([f1_score(pred, ans) for ans in Y["answers"]])
+    c += 1
+    print(f"f1: {f1/c}, raw: {f1}, c: {c}\n")
+
+    st = time.perf_counter()
+
+def eval_mrcnn():
+  from tqdm import tqdm
+  from extra.models.mask_rcnn import MaskRCNN
+  from extra.models.resnet import ResNet
+  from extra.datasets.coco import BASEDIR, images, convert_prediction_to_coco_bbox, convert_prediction_to_coco_mask, accumulate_predictions_for_coco, evaluate_predictions_on_coco, iterate
+  from examples.mask_rcnn import compute_prediction_batched, Image
+  mdl = MaskRCNN(ResNet(50, num_classes=None, stride_in_1x1=True))
+  mdl.load_from_pretrained()
+
+  bbox_output = '/tmp/results_bbox.json'
+  mask_output = '/tmp/results_mask.json'
+
+  accumulate_predictions_for_coco([], bbox_output, rm=True)
+  accumulate_predictions_for_coco([], mask_output, rm=True)
+
+  #TODO: bs > 1 not as accurate
+  bs = 1
+
+  for batch in tqdm(iterate(images, bs=bs), total=len(images)//bs):
+    batch_imgs = []
+    for image_row in batch:
+      image_name = image_row['file_name']
+      img = Image.open(BASEDIR/f'val2017/{image_name}').convert("RGB")
+      batch_imgs.append(img)
+    batch_result = compute_prediction_batched(batch_imgs, mdl)
+    for image_row, result in zip(batch, batch_result):
+      image_name = image_row['file_name']
+      box_pred = convert_prediction_to_coco_bbox(image_name, result)
+      mask_pred = convert_prediction_to_coco_mask(image_name, result)
+      accumulate_predictions_for_coco(box_pred, bbox_output)
+      accumulate_predictions_for_coco(mask_pred, mask_output)
+    del batch_imgs
+    del batch_result
+
+  evaluate_predictions_on_coco(bbox_output, iou_type='bbox')
+  evaluate_predictions_on_coco(mask_output, iou_type='segm')
+
+if __name__ == "__main__":
+  # inference only
+  Tensor.training = False
+  Tensor.no_grad = True
+
+  models = getenv("MODEL", "resnet,retinanet,unet3d,rnnt,bert,mrcnn").split(",")
+  for m in models:
+    nm = f"eval_{m}"
+    if nm in globals():
+      print(f"eval {m}")
+      globals()[nm]()
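
eval_resnet keeps padded samples out of the accuracy count: batch_load_resnet is called with pad_first_batch=True and the padded labels show up as -1. A numpy sketch of that bookkeeping with made-up predictions:

import numpy as np

preds  = np.array([3, 7, 1, 0])    # model argmax outputs for one batch
labels = np.array([3, 7, 2, -1])   # -1 marks a padding sample
hits = (preds == labels) & (labels != -1)
n, d = hits.sum(), (labels != -1).sum()
print(n, d, n / d)  # 2 3 0.666... -- padding neither helps nor hurts top-1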

+ 70 - 0
tinychat/examples/mlperf/model_spec.py

@@ -0,0 +1,70 @@
+# load each model here, quick benchmark
+from tinygrad import Tensor, GlobalCounters
+from tinygrad.helpers import getenv
+import numpy as np
+
+def test_model(model, *inputs):
+  GlobalCounters.reset()
+  out = model(*inputs)
+  if isinstance(out, Tensor): out = out.numpy()
+  # TODO: return event future to still get the time_sum_s without DEBUG=2
+  print(f"{GlobalCounters.global_ops*1e-9:.2f} GOPS, {GlobalCounters.time_sum_s*1000:.2f} ms")
+
+def spec_resnet():
+  # Resnet50-v1.5
+  from extra.models.resnet import ResNet50
+  mdl = ResNet50()
+  img = Tensor.randn(1, 3, 224, 224)
+  test_model(mdl, img)
+
+def spec_retinanet():
+  # Retinanet with ResNet backbone
+  from extra.models.resnet import ResNet50
+  from extra.models.retinanet import RetinaNet
+  mdl = RetinaNet(ResNet50(), num_classes=91, num_anchors=9)
+  img = Tensor.randn(1, 3, 224, 224)
+  test_model(mdl, img)
+
+def spec_unet3d():
+  # 3D UNET
+  from extra.models.unet3d import UNet3D
+  mdl = UNet3D()
+  #mdl.load_from_pretrained()
+  img = Tensor.randn(1, 1, 128, 128, 128)
+  test_model(mdl, img)
+
+def spec_rnnt():
+  from extra.models.rnnt import RNNT
+  mdl = RNNT()
+  #mdl.load_from_pretrained()
+  x = Tensor.randn(220, 1, 240)
+  y = Tensor.randn(1, 220)
+  test_model(mdl, x, y)
+
+def spec_bert():
+  from extra.models.bert import BertForQuestionAnswering
+  mdl = BertForQuestionAnswering()
+  #mdl.load_from_pretrained()
+  x = Tensor.randn(1, 384)
+  am = Tensor.randn(1, 384)
+  tt = Tensor(np.random.randint(0, 2, (1, 384)).astype(np.float32))
+  test_model(mdl, x, am, tt)
+
+def spec_mrcnn():
+  from extra.models.mask_rcnn import MaskRCNN, ResNet
+  mdl = MaskRCNN(ResNet(50, num_classes=None, stride_in_1x1=True))
+  #mdl.load_from_pretrained()
+  x = Tensor.randn(3, 224, 224)
+  test_model(mdl, [x])
+
+if __name__ == "__main__":
+  # inference only for now
+  Tensor.training = False
+  Tensor.no_grad = True
+
+  for m in getenv("MODEL", "resnet,retinanet,unet3d,rnnt,bert,mrcnn").split(","):
+    nm = f"spec_{m}"
+    if nm in globals():
+      print(f"testing {m}")
+      globals()[nm]()
+
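
Both model_spec.py and model_eval.py pick what to run from a MODEL environment variable and a globals() lookup, silently skipping names without a matching function. A minimal standalone sketch of that dispatch pattern (plain os.environ here instead of tinygrad's getenv):

import os

def spec_resnet(): print("benchmarking resnet")
def spec_bert():   print("benchmarking bert")

for m in os.environ.get("MODEL", "resnet,bert").split(","):
  fn = globals().get(f"spec_{m}")
  if fn: fn()  # unknown names (e.g. MODEL=rnnt here) are simply skipped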

+ 691 - 0
tinychat/examples/mlperf/model_train.py

@@ -0,0 +1,691 @@
+import os, time, math, functools
+from pathlib import Path
+from tqdm import tqdm
+import multiprocessing
+
+from tinygrad import Device, GlobalCounters, Tensor, TinyJit, dtypes
+from tinygrad.helpers import getenv, BEAM, WINO, round_up, diskcache_clear
+from tinygrad.nn.state import get_parameters, get_state_dict, safe_load, safe_save
+from tinygrad.nn.optim import LAMB, LARS, SGD, OptimizerGroup
+
+from extra.lr_scheduler import LRSchedulerGroup
+from examples.mlperf.helpers import get_training_state, load_training_state
+
+def train_resnet():
+  from extra.models import resnet
+  from examples.mlperf.dataloader import batch_load_resnet
+  from extra.datasets.imagenet import get_train_files, get_val_files
+  from examples.mlperf.lr_schedulers import PolynomialDecayWithWarmup
+  from examples.mlperf.initializers import Conv2dHeNormal, Linear
+  from examples.hlb_cifar10 import UnsyncedBatchNorm
+
+  config = {}
+  seed = config["seed"] = getenv("SEED", 42)
+  Tensor.manual_seed(seed)  # seed for weight initialization
+
+  INITMLPERF = getenv("INITMLPERF")
+  RUNMLPERF = getenv("RUNMLPERF")
+  if getenv("LOGMLPERF"):
+    from mlperf_logging import mllog
+    import mlperf_logging.mllog.constants as mllog_constants
+    mllog.config(filename=f"result_{seed}.txt")
+    mllog.config(root_dir=Path(__file__).parents[3].as_posix())  # truncate to log this. "file": "tinygrad/examples/mlperf/model_train.py"
+    MLLOGGER = mllog.get_mllogger()
+    if INITMLPERF:
+      # common.yaml
+      MLLOGGER.event(key=mllog_constants.SUBMISSION_ORG, value="tinycorp")
+      MLLOGGER.event(key=mllog_constants.SUBMISSION_PLATFORM, value=getenv("SUBMISSION_PLATFORM", "tinybox"))
+      MLLOGGER.event(key=mllog_constants.SUBMISSION_DIVISION, value=mllog_constants.CLOSED)
+      MLLOGGER.event(key=mllog_constants.SUBMISSION_STATUS, value=mllog_constants.ONPREM)
+      # closed_common.yaml
+      MLLOGGER.event(key=mllog_constants.SUBMISSION_BENCHMARK, value=mllog_constants.RESNET)
+      diskcache_clear()
+      MLLOGGER.event(key=mllog_constants.CACHE_CLEAR, value=True)
+      MLLOGGER.start(key=mllog_constants.INIT_START)
+    if RUNMLPERF:
+      MLLOGGER.start(key=mllog_constants.RUN_START)
+      MLLOGGER.event(key=mllog_constants.SEED, value=seed)
+  else:
+    MLLOGGER = None
+
+  GPUS = config["GPUS"] = [f"{Device.DEFAULT}:{i}" for i in range(getenv("GPUS", 1))]
+  print(f"training on {GPUS}")
+  for x in GPUS: Device[x]
+
+  TRAIN_BEAM = getenv("TRAIN_BEAM", BEAM.value)
+  EVAL_BEAM = getenv("EVAL_BEAM", BEAM.value)
+
+  # ** model definition and initializers **
+  num_classes = 1000
+  resnet.Conv2d = Conv2dHeNormal
+  resnet.Linear = Linear
+  if not getenv("SYNCBN"): resnet.BatchNorm = functools.partial(UnsyncedBatchNorm, num_devices=len(GPUS))
+  model = resnet.ResNet50(num_classes)
+
+  # shard weights and initialize in order
+  for k, x in get_state_dict(model).items():
+    if not getenv("SYNCBN") and ("running_mean" in k or "running_var" in k):
+      x.realize().shard_(GPUS, axis=0)
+    else:
+      x.realize().to_(GPUS)
+  parameters = get_parameters(model)
+
+  # ** hyperparameters **
+  epochs            = config["epochs"]            = getenv("EPOCHS", 37)
+  BS                = config["BS"]                = getenv("BS", 104 * len(GPUS))  # fp32 GPUS<=6 7900xtx can fit BS=112
+  EVAL_BS           = config["EVAL_BS"]           = getenv("EVAL_BS", BS)
+  base_lr           = config["base_lr"]           = getenv("LR", 7.2 * (BS/1536))
+  lr_warmup_epochs  = config["lr_warmup_epochs"]  = getenv("WARMUP_EPOCHS", 2)
+  decay             = config["decay"]             = getenv("DECAY", 2e-4)
+
+  loss_scaler       = config["LOSS_SCALER"]       = getenv("LOSS_SCALER", 128.0 if dtypes.default_float == dtypes.float16 else 1.0)
+
+  target, achieved  = getenv("TARGET", 0.759), False
+  eval_start_epoch  = getenv("EVAL_START_EPOCH", 0)
+  eval_freq         = getenv("EVAL_FREQ", 1)
+
+  steps_in_train_epoch  = config["steps_in_train_epoch"]  = (round_up(len(get_train_files()), BS) // BS)
+  steps_in_val_epoch    = config["steps_in_val_epoch"]    = (round_up(len(get_val_files()), EVAL_BS) // EVAL_BS)
+
+  config["DEFAULT_FLOAT"] = dtypes.default_float.name
+  config["BEAM"]          = BEAM.value
+  config["TRAIN_BEAM"]    = TRAIN_BEAM
+  config["EVAL_BEAM"]     = EVAL_BEAM
+  config["WINO"]          = WINO.value
+  config["SYNCBN"]        = getenv("SYNCBN")
+
+  # ** Optimizer **
+  skip_list = [v for k, v in get_state_dict(model).items() if "bn" in k or "bias" in k or "downsample.1" in k]
+  parameters = [x for x in parameters if x not in set(skip_list)]
+  optimizer = LARS(parameters, base_lr, momentum=.9, weight_decay=decay)
+  optimizer_skip = SGD(skip_list, base_lr, momentum=.9, weight_decay=0.0, classic=True)
+  optimizer_group = OptimizerGroup(optimizer, optimizer_skip)
+
+  # ** LR scheduler **
+  scheduler = PolynomialDecayWithWarmup(optimizer, initial_lr=base_lr, end_lr=1e-4,
+                                        train_steps=epochs * steps_in_train_epoch,
+                                        warmup=lr_warmup_epochs * steps_in_train_epoch)
+  scheduler_skip = PolynomialDecayWithWarmup(optimizer_skip, initial_lr=base_lr, end_lr=1e-4,
+                                             train_steps=epochs * steps_in_train_epoch,
+                                             warmup=lr_warmup_epochs * steps_in_train_epoch)
+  scheduler_group = LRSchedulerGroup(scheduler, scheduler_skip)
+  print(f"training with batch size {BS} for {epochs} epochs")
+
+  # log mlperf hparams
+  if MLLOGGER:
+    if RUNMLPERF:
+      MLLOGGER.event(key=mllog_constants.GLOBAL_BATCH_SIZE, value=BS)
+      from extra.datasets.imagenet import get_train_files, get_val_files
+      MLLOGGER.event(key=mllog_constants.TRAIN_SAMPLES, value=len(get_train_files()))
+      MLLOGGER.event(key=mllog_constants.EVAL_SAMPLES, value=len(get_val_files()))
+
+      MLLOGGER.event(key=mllog_constants.GRADIENT_ACCUMULATION_STEPS, value=1)
+      MLLOGGER.event(key=mllog_constants.OPT_NAME, value="lars")
+      assert scheduler.initial_lr == scheduler_skip.initial_lr
+      assert scheduler.end_lr == scheduler_skip.end_lr
+      assert scheduler.power == scheduler_skip.power
+      MLLOGGER.event(key=mllog_constants.LARS_OPT_BASE_LEARNING_RATE, value=scheduler.initial_lr)
+      MLLOGGER.event(key=mllog_constants.LARS_OPT_END_LR, value=scheduler.end_lr)
+      MLLOGGER.event(key=mllog_constants.LARS_OPT_LR_DECAY_POLY_POWER, value=scheduler.power)
+      MLLOGGER.event(key=mllog_constants.LARS_OPT_LR_DECAY_STEPS, value=epochs)
+      MLLOGGER.event(key=mllog_constants.LARS_EPSILON, value=0)  # does not support epsilon != 0
+      MLLOGGER.event(key=mllog_constants.LARS_OPT_LEARNING_RATE_WARMUP_EPOCHS, value=lr_warmup_epochs)
+      MLLOGGER.event(key=mllog_constants.LARS_OPT_MOMENTUM, value=optimizer.momentum)
+      MLLOGGER.event(key=mllog_constants.LARS_OPT_WEIGHT_DECAY, value=optimizer.wd)
+
+  # ** resume from checkpointing **
+  start_epoch = 0
+  if ckpt:=getenv("RESUME", ""):
+    load_training_state(model, optimizer_group, scheduler_group, safe_load(ckpt))
+    start_epoch = int(scheduler.epoch_counter.numpy().item() / steps_in_train_epoch)
+    print(f"resuming from {ckpt} at epoch {start_epoch}")
+
+  # ** init wandb **
+  WANDB = getenv("WANDB")
+  if WANDB:
+    import wandb
+    wandb_args = {"id": wandb_id, "resume": "must"} if (wandb_id := getenv("WANDB_RESUME", "")) else {}
+    wandb.init(config=config, **wandb_args)
+
+  BENCHMARK = getenv("BENCHMARK")
+
+  # ** jitted steps **
+  input_mean = Tensor([123.68, 116.78, 103.94], device=GPUS, dtype=dtypes.float32).reshape(1, -1, 1, 1)
+  # mlperf reference resnet does not divide by input_std for some reason
+  # input_std = Tensor([0.229, 0.224, 0.225], device=GPUS, dtype=dtypes.float32).reshape(1, -1, 1, 1)
+  def normalize(x): return (x.permute([0, 3, 1, 2]) - input_mean).cast(dtypes.default_float)
+  @TinyJit
+  def train_step(X, Y):
+    optimizer_group.zero_grad()
+    X = normalize(X)
+    out = model.forward(X)
+    loss = out.cast(dtypes.float32).sparse_categorical_crossentropy(Y, label_smoothing=0.1)
+    top_1 = (out.argmax(-1) == Y).sum()
+    (loss * loss_scaler).backward()
+    for t in optimizer_group.params: t.grad = t.grad.contiguous() / loss_scaler
+    optimizer_group.step()
+    scheduler_group.step()
+    return loss.realize(), top_1.realize()
+
+  @TinyJit
+  def eval_step(X, Y):
+    X = normalize(X)
+    out = model.forward(X)
+    loss = out.cast(dtypes.float32).sparse_categorical_crossentropy(Y, label_smoothing=0.1)
+    top_1 = (out.argmax(-1) == Y).sum()
+    return loss.realize(), top_1.realize()
+
+  def fake_data_get(batch_size):
+    x = Tensor.zeros(batch_size, 224, 224, 3, dtype=dtypes.uchar).contiguous()
+    y = [0] * batch_size
+    return x.shard(GPUS, axis=0).realize(), Tensor(y, requires_grad=False).shard(GPUS, axis=0), y, None
+
+  def data_get(it):
+    x, y, cookie = next(it)
+    return x.shard(GPUS, axis=0).realize(), Tensor(y, requires_grad=False).shard(GPUS, axis=0), y, cookie
+
+  # ** epoch loop **
+  step_times = []
+  for e in range(start_epoch, epochs):
+    # ** train loop **
+    if MLLOGGER and RUNMLPERF:
+      MLLOGGER.start(key=mllog_constants.EPOCH_START, value=e+1, metadata=dict(epoch_num=e+1))
+    Tensor.training = True
+    BEAM.value = TRAIN_BEAM
+
+    if INITMLPERF:
+      i, proc = 0, fake_data_get(BS)
+    else:
+      batch_loader = batch_load_resnet(batch_size=BS, val=False, shuffle=True, seed=seed*epochs + e, pad_first_batch=True)
+      it = iter(tqdm(batch_loader, total=steps_in_train_epoch, desc=f"epoch {e}", disable=BENCHMARK))
+      i, proc = 0, data_get(it)
+
+    prev_cookies = []
+    st = time.perf_counter()
+    while proc is not None:
+      GlobalCounters.reset()
+      (loss, top_1), y, proc = train_step(proc[0], proc[1]), proc[2], proc[3]
+
+      pt = time.perf_counter()
+
+      if len(prev_cookies) == getenv("STORE_COOKIES", 1): prev_cookies = []  # free previous cookies after gpu work has been enqueued
+      try:
+        if INITMLPERF:
+          next_proc = fake_data_get(BS)
+        else:
+          next_proc = data_get(it)
+      except StopIteration:
+        next_proc = None
+
+      dt = time.perf_counter()
+
+      device_str = loss.device if isinstance(loss.device, str) else f"{loss.device[0]} * {len(loss.device)}"
+      loss, top_1 = loss.numpy().item(), top_1.numpy().item()
+      top_1_acc = top_1 / sum(yi != -1 for yi in y)
+
+      cl = time.perf_counter()
+      if BENCHMARK:
+        step_times.append(cl - st)
+
+      tqdm.write(
+        f"{i:5} {((cl - st)) * 1000.0:7.2f} ms run, {(pt - st) * 1000.0:7.2f} ms python, {(dt - pt) * 1000.0:6.2f} ms fetch data, "
+        f"{(cl - dt) * 1000.0:7.2f} ms {device_str}, {loss:5.2f} loss, {top_1_acc:3.2f} acc, {optimizer.lr.numpy()[0]:.6f} LR, "
+        f"{GlobalCounters.mem_used / 1e9:.2f} GB used, {GlobalCounters.global_ops * 1e-9 / (cl - st):9.2f} GFLOPS")
+      if WANDB:
+        wandb.log({"lr": optimizer.lr.numpy(), "train/loss": loss, "train/top_1_acc": top_1_acc, "train/step_time": cl - st,
+                   "train/python_time": pt - st, "train/data_time": dt - pt, "train/cl_time": cl - dt,
+                   "train/GFLOPS": GlobalCounters.global_ops * 1e-9 / (cl - st), "epoch": e + (i + 1) / steps_in_train_epoch})
+
+      st = cl
+      prev_cookies.append(proc)
+      proc, next_proc = next_proc, None  # return old cookie
+      i += 1
+
+      if i == BENCHMARK:
+        assert not math.isnan(loss)
+        median_step_time = sorted(step_times)[(BENCHMARK + 1) // 2]  # in seconds
+        estimated_total_minutes = int(median_step_time * steps_in_train_epoch * epochs / 60)
+        print(f"Estimated training time: {estimated_total_minutes // 60}h{estimated_total_minutes % 60}m")
+        print(f"epoch global_ops: {steps_in_train_epoch * GlobalCounters.global_ops:_}, "
+              f"epoch global_mem: {steps_in_train_epoch * GlobalCounters.global_mem:_}")
+        # if we are doing beam search, run the first eval too
+        if (TRAIN_BEAM or EVAL_BEAM) and e == start_epoch: break
+        return
+    if MLLOGGER and RUNMLPERF:
+      MLLOGGER.event(key=mllog_constants.EPOCH_STOP, value=e+1, metadata=dict(epoch_num=e+1))
+
+    # ** eval loop **
+    # always eval for epoch >= 33 to stop the clock as soon as the eval target is hit; convergence can happen anywhere in epochs [33, 37]
+    if steps_in_val_epoch > 0 and ((e + 1 - eval_start_epoch) % eval_freq == 0 or e + 1 >= 33):
+      if MLLOGGER and RUNMLPERF:
+        MLLOGGER.start(key=mllog_constants.EVAL_START, value=e+1, metadata=dict(epoch_num=e+1))
+      if getenv("RESET_STEP", 1): train_step.reset()  # free the train step memory :(
+      eval_times = []
+      eval_loss = 0.0
+      eval_top_1 = 0
+      eval_num_samples = 0
+      Tensor.training = False
+      BEAM.value = EVAL_BEAM
+
+      if INITMLPERF:
+        i, proc = 0, fake_data_get(EVAL_BS)
+      else:
+        it = iter(tqdm(batch_load_resnet(batch_size=EVAL_BS, val=True, shuffle=False, pad_first_batch=True), total=steps_in_val_epoch))
+        i, proc = 0, data_get(it)
+        
+      prev_cookies = []
+      while proc is not None:
+        GlobalCounters.reset()
+        st = time.time()
+
+        (loss, top_1), y, proc = eval_step(proc[0], proc[1]), proc[2], proc[3]  # drop inputs, keep cookie
+
+        if len(prev_cookies) == getenv("STORE_COOKIES", 1): prev_cookies = []  # free previous cookies after gpu work has been enqueued
+        try:
+          if INITMLPERF:
+            next_proc = fake_data_get(EVAL_BS)
+          else:
+            next_proc = data_get(it)
+        except StopIteration:
+          next_proc = None
+
+        loss, top_1 = loss.numpy().item(), top_1.numpy().item()
+        num_samples = sum(yi != -1 for yi in y)
+        eval_loss += loss * num_samples
+        eval_top_1 += top_1
+        eval_num_samples += num_samples
+        prev_cookies.append(proc)
+        proc, next_proc = next_proc, None
+        i += 1
+        if i == BENCHMARK:
+          # assume INITMLPERF has BENCHMARK set
+          if MLLOGGER and INITMLPERF:
+            MLLOGGER.event(key=mllog_constants.INIT_STOP)
+          return
+
+        et = time.time()
+        eval_times.append(et - st)
+
+      if getenv("RESET_STEP", 1): eval_step.reset()
+      if not BENCHMARK:
+        assert eval_num_samples == len(get_val_files()), f"eval sample count mismatched. {eval_num_samples=} != {len(get_val_files())}"
+      total_loss = eval_loss / eval_num_samples
+      total_top_1 = eval_top_1 / eval_num_samples
+      total_fw_time = sum(eval_times) / len(eval_times)
+      tqdm.write(f"eval loss: {total_loss:.2f}, eval time: {total_fw_time:.2f}, eval top 1 acc: {total_top_1:.3f}")
+      if WANDB:
+        wandb.log({"eval/loss": total_loss, "eval/top_1_acc": total_top_1, "eval/forward_time": total_fw_time, "epoch": e + 1})
+      if MLLOGGER and RUNMLPERF:
+        MLLOGGER.event(key=mllog_constants.EVAL_ACCURACY, value=total_top_1, metadata=dict(epoch_num=e+1))
+        MLLOGGER.event(key=mllog_constants.EVAL_STOP, value=e+1, metadata=dict(epoch_num=e+1))
+
+      # save model if achieved target
+      if not achieved and total_top_1 >= target:
+        # stop once the target is achieved
+        if MLLOGGER and RUNMLPERF:
+          MLLOGGER.event(key=mllog_constants.RUN_STOP, metadata=dict(status=mllog_constants.SUCCESS))
+        if not os.path.exists("./ckpts"): os.mkdir("./ckpts")
+        fn = f"./ckpts/resnet50_{seed}.safe"
+        safe_save(get_state_dict(model), fn)
+        print(f" *** Model saved to {fn} ***")
+        achieved = True
+        break
+
+      # checkpoint every time we eval
+      if getenv("CKPT"):
+        if not os.path.exists("./ckpts"): os.mkdir("./ckpts")
+        if WANDB and wandb.run is not None:
+          fn = f"./ckpts/{time.strftime('%Y%m%d_%H%M%S')}_{wandb.run.id}_e{e}.safe"
+        else:
+          fn = f"./ckpts/{time.strftime('%Y%m%d_%H%M%S')}_e{e}.safe"
+        print(f"saving ckpt to {fn}")
+        safe_save(get_training_state(model, optimizer_group, scheduler_group), fn)
+
+def train_retinanet():
+  # TODO: Retinanet
+  pass
+
+def train_unet3d():
+  # TODO: Unet3d
+  pass
+
+def train_rnnt():
+  # TODO: RNN-T
+  pass
+
+@TinyJit
+def train_step_bert(model, optimizer, scheduler, loss_scaler:float, input_ids:Tensor, segment_ids:Tensor, attention_mask:Tensor, masked_positions:Tensor, masked_lm_ids:Tensor, masked_lm_weights:Tensor, next_sentence_labels:Tensor):
+  optimizer.zero_grad()
+
+  lm_logits, seq_relationship_logits = model(input_ids, attention_mask, masked_positions, segment_ids)
+  loss = model.loss(lm_logits, seq_relationship_logits, masked_lm_ids, masked_lm_weights, next_sentence_labels)
+  (loss * loss_scaler).backward()
+
+  global_norm = Tensor([0.0], dtype=dtypes.float32, device=optimizer[0].device).realize()
+  for p in optimizer.params: 
+    p.grad = p.grad / loss_scaler
+    global_norm += p.grad.float().square().sum()
+  global_norm = global_norm.sqrt()
+  for p in optimizer.params: p.grad = (p.grad / Tensor.where(global_norm > 1.0, global_norm, 1.0)).cast(p.grad.dtype)
+
+  optimizer.step()
+  scheduler.step()
+  return loss.realize()
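
# train_step_bert above rescales every gradient by the global L2 norm over all gradients before the
# optimizer step, but only when that norm exceeds 1.0; a numpy sketch of the same clipping rule
# (illustration only, not the tinygrad API):
import numpy as np
grads = [np.array([3.0, 4.0]), np.array([12.0])]           # per-parameter gradients
global_norm = np.sqrt(sum((g * g).sum() for g in grads))   # sqrt(9 + 16 + 144) = 13.0
clipped = [g / max(global_norm, 1.0) for g in grads]       # rescale only when the norm exceeds 1
print(global_norm, clipped[0])                             # 13.0 [0.23076923 0.30769231]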
+
+@TinyJit
+def eval_step_bert(model, input_ids:Tensor, segment_ids:Tensor, attention_mask:Tensor, masked_positions:Tensor, masked_lm_ids:Tensor, masked_lm_weights:Tensor, next_sentence_labels:Tensor):
+  lm_logits, seq_relationship_logits = model(input_ids, attention_mask, masked_positions, segment_ids)
+  masked_lm_accuracy, seq_relationship_accuracy, masked_lm_loss, next_sentence_loss = model.accuracy(lm_logits, seq_relationship_logits, masked_lm_ids, masked_lm_weights, next_sentence_labels)
+  return {
+    "masked_lm_accuracy": masked_lm_accuracy.realize(),
+    "next_sentence_accuracy": seq_relationship_accuracy.realize(),
+    "masked_lm_loss": masked_lm_loss.realize(),
+    "next_sentence_loss": next_sentence_loss.realize()
+  }
+
+def train_bert():
+  # NOTE: pip install tensorflow, wandb required
+  from examples.mlperf.dataloader import batch_load_train_bert, batch_load_val_bert
+  from examples.mlperf.helpers import get_mlperf_bert_model, get_data_bert, get_fake_data_bert
+  from examples.mlperf.lr_schedulers import PolynomialDecayWithWarmup
+
+  config = {}
+  BASEDIR = getenv("BASEDIR", Path(__file__).parent.parents[1] / "extra" / "datasets" / "wiki")
+
+  GPUS = config["GPUS"] = [f"{Device.DEFAULT}:{i}" for i in range(getenv("GPUS", 1))]
+  print(f"training on {GPUS}")
+  for x in GPUS: Device[x]
+  seed = config["seed"] = getenv("SEED", 12345)
+
+  INITMLPERF = getenv("INITMLPERF")
+  RUNMLPERF = getenv("RUNMLPERF")
+  if getenv("LOGMLPERF"):
+    from mlperf_logging import mllog
+    import mlperf_logging.mllog.constants as mllog_constants
+
+    mllog.config(filename="bert.log")
+    mllog.config(root_dir=Path(__file__).parents[3].as_posix())
+    MLLOGGER = mllog.get_mllogger()
+    MLLOGGER.logger.propagate = False
+
+    if INITMLPERF:
+      assert BENCHMARK, "BENCHMARK must be set for INITMLPERF"
+      MLLOGGER.event(key=mllog_constants.SUBMISSION_ORG, value="tinycorp")
+      MLLOGGER.event(key=mllog_constants.SUBMISSION_PLATFORM, value=getenv("SUBMISSION_PLATFORM", "tinybox"))
+      MLLOGGER.event(key=mllog_constants.SUBMISSION_DIVISION, value=mllog_constants.CLOSED)
+      MLLOGGER.event(key=mllog_constants.SUBMISSION_STATUS, value=mllog_constants.ONPREM)
+
+      MLLOGGER.event(key=mllog_constants.SUBMISSION_BENCHMARK, value=mllog_constants.BERT)
+
+      diskcache_clear()
+      MLLOGGER.event(key=mllog_constants.CACHE_CLEAR, value=True)
+      MLLOGGER.start(key=mllog_constants.INIT_START, value=None)
+
+    if RUNMLPERF:
+      MLLOGGER.start(key=mllog_constants.RUN_START, value=None)
+  else:
+    MLLOGGER = None
+
+  # ** hyperparameters **
+  BS                 = config["GLOBAL_BATCH_SIZE"]      = getenv("BS", 16 * len(GPUS) if dtypes.default_float in (dtypes.float16, dtypes.bfloat16) else 8 * len(GPUS))
+  EVAL_BS            = config["EVAL_BS"]                = getenv("EVAL_BS", 1 * len(GPUS))
+  max_lr             = config["OPT_BASE_LEARNING_RATE"] = getenv("OPT_BASE_LEARNING_RATE", 0.00035 * math.sqrt(BS/256))
+
+  train_steps        = config["TRAIN_STEPS"]            = getenv("TRAIN_STEPS", 4800000 // BS)
+  warmup_steps       = config["NUM_WARMUP_STEPS"]       = getenv("NUM_WARMUP_STEPS", 1)
+  max_eval_steps     = config["MAX_EVAL_STEPS"]         = getenv("MAX_EVAL_STEPS", (10000 + EVAL_BS - 1) // EVAL_BS) # EVAL_BS * MAX_EVAL_STEPS >= 10000
+  eval_step_freq     = config["EVAL_STEP_FREQ"]         = getenv("EVAL_STEP_FREQ", int((math.floor(0.05 * (230.23 * BS + 3000000) / 25000) * 25000) / BS)) # Round down
+  save_ckpt_freq     = config["SAVE_CKPT_FREQ"]         = getenv("SAVE_CKPT_FREQ", 1000)
+  keep_ckpt_amount   = config["KEEP_CKPT_AMOUNT"]       = getenv("KEEP_CKPT_AMOUNT", 5)
+  save_ckpt_dir      = config["SAVE_CKPT_DIR"]          = getenv("SAVE_CKPT_DIR", "./ckpts")
+  init_ckpt          = config["INIT_CKPT_DIR"]          = getenv("INIT_CKPT_DIR", BASEDIR)
+
+  loss_scaler        = config["LOSS_SCALER"]            = getenv("LOSS_SCALER", 2.0**9 if dtypes.default_float == dtypes.float16 else 1.0)
+  decay              = config["DECAY"]                  = getenv("DECAY", 0.01)
+  epsilon            = config["EPSILON"]                = getenv("EPSILON", 1e-6)
+  poly_power         = config["POLY_POWER"]             = getenv("POLY_POWER", 1.0)
+
+  target, achieved                                      = getenv("TARGET", 0.72), False
+
+  config["DEFAULT_FLOAT"] = dtypes.default_float.name
+  config["DISABLE_DROPOUT"] = getenv("DISABLE_DROPOUT", 0)
+  config["TRAIN_BEAM"]    = TRAIN_BEAM = getenv("TRAIN_BEAM", BEAM.value)
+  config["EVAL_BEAM"]     = EVAL_BEAM  = getenv("EVAL_BEAM", BEAM.value)
+
+  Tensor.manual_seed(seed)  # seed for weight initialization
+
+  model = get_mlperf_bert_model(init_ckpt)
+  
+  for _, x in get_state_dict(model).items():
+    x.realize().to_(GPUS)
+  parameters = get_parameters(model)
+
+  assert 10000 <= (EVAL_BS * max_eval_steps), "Evaluation batch size * max_eval_steps must be >= 10000 to iterate over the full eval dataset"
+
+  # ** Log run config **
+  for key, value in config.items(): print(f'HParam: "{key}": {value}')
+
+  # ** Optimizer **
+  parameters_no_wd = [v for k, v in get_state_dict(model).items() if "bias" in k or "LayerNorm" in k]
+  parameters = [x for x in parameters if x not in set(parameters_no_wd)]
+  optimizer_wd = LAMB(parameters, lr=max_lr, eps=epsilon, weight_decay=decay, adam=False)
+  optimizer_no_wd = LAMB(parameters_no_wd, lr=max_lr, eps=epsilon, weight_decay=0.0, adam=False)
+  optimizer_group = OptimizerGroup(optimizer_wd, optimizer_no_wd)
+
+  # ** LR scheduler **
+  scheduler_wd = PolynomialDecayWithWarmup(optimizer_wd, max_lr, 0, train_steps, warmup_steps, power=poly_power)
+  scheduler_no_wd = PolynomialDecayWithWarmup(optimizer_no_wd, max_lr, 0, train_steps, warmup_steps, power=poly_power)
+  scheduler_group = LRSchedulerGroup(scheduler_wd, scheduler_no_wd)
+  print(f"training with batch size {BS} for one epoch with {train_steps} steps")
+
+  # log mlperf hparams
+  if MLLOGGER:
+    if RUNMLPERF:
+      MLLOGGER.event(key=mllog_constants.GLOBAL_BATCH_SIZE, value=config["GLOBAL_BATCH_SIZE"])
+      MLLOGGER.event(key=mllog_constants.MAX_SEQUENCE_LENGTH, value=512)
+      MLLOGGER.event(key="max_predictions_per_seq", value=76)
+
+      MLLOGGER.event(key=mllog_constants.OPT_NAME, value="LAMB")
+      MLLOGGER.event(key=mllog_constants.OPT_BASE_LR, value=config["OPT_BASE_LEARNING_RATE"])
+      MLLOGGER.event(key=mllog_constants.OPT_LAMB_WEIGHT_DECAY, value=config["DECAY"])
+      MLLOGGER.event(key=mllog_constants.OPT_LAMB_BETA_1, value=optimizer_wd.b1)
+      MLLOGGER.event(key=mllog_constants.OPT_LAMB_BETA_2, value=optimizer_wd.b2)
+      MLLOGGER.event(key=mllog_constants.OPT_LAMB_LR_DECAY_POLY_POWER, value=config["POLY_POWER"])
+      MLLOGGER.event(key=mllog_constants.OPT_LAMB_EPSILON, value=config["EPSILON"])
+
+      MLLOGGER.event(key=mllog_constants.OPT_LR_WARMUP_STEPS, value=config["NUM_WARMUP_STEPS"])
+      MLLOGGER.event(key=mllog_constants.NUM_WARMUP_STEPS, value=config["NUM_WARMUP_STEPS"])
+      MLLOGGER.event(key='start_warmup_step', value=0)
+      MLLOGGER.event(key='opt_learning_rate_training_steps', value=config["TRAIN_STEPS"])
+      MLLOGGER.event(key=mllog_constants.GRADIENT_ACCUMULATION_STEPS, value=1)
+      MLLOGGER.event(key=mllog_constants.EVAL_SAMPLES, value=config["EVAL_BS"] * config["MAX_EVAL_STEPS"])
+      MLLOGGER.event(key=mllog_constants.TRAIN_SAMPLES, value=config["GLOBAL_BATCH_SIZE"] * config["TRAIN_STEPS"])
+
+  # ** resume from checkpointing **
+  start_step = 1
+  previous_step = None
+  if ckpt:=getenv("RESUME", ""):
+    load_training_state(model, optimizer_group, scheduler_group, safe_load(ckpt))
+    start_step = int(scheduler_wd.epoch_counter.numpy().item())
+    print(f"resuming from {ckpt} at step {start_step}")
+
+  # ** init wandb **
+  WANDB = getenv("WANDB")
+  if WANDB:
+    import wandb
+    wandb_args = {"id": wandb_id, "resume": "must"} if (wandb_id := getenv("WANDB_RESUME", "")) else {}
+    wandb.init(config=config, **wandb_args, project="MLPerf-BERT")
+
+  BENCHMARK = getenv("BENCHMARK")
+
+  if not INITMLPERF:
+    eval_it = iter(batch_load_val_bert(EVAL_BS))
+    train_it = iter(tqdm(batch_load_train_bert(BS, start_step), initial=start_step, total=train_steps, disable=BENCHMARK))
+
+  step_times = []
+  # ** train loop **
+  wc_start = time.perf_counter()
+  if INITMLPERF:
+    i, train_data = start_step, get_fake_data_bert(GPUS, BS)
+  else:
+    i, train_data = start_step, get_data_bert(GPUS, train_it)
+  while train_data is not None and i < train_steps and not achieved:
+    Tensor.training = True
+    BEAM.value = TRAIN_BEAM
+    st = time.perf_counter()
+    GlobalCounters.reset()
+    loss = train_step_bert(model, optimizer_group, scheduler_group, loss_scaler,
+      train_data["input_ids"], train_data["segment_ids"], train_data["input_mask"], train_data["masked_lm_positions"], \
+      train_data["masked_lm_ids"], train_data["masked_lm_weights"], train_data["next_sentence_labels"])
+
+    pt = time.perf_counter()
+
+    try:
+      if INITMLPERF:
+        next_data = get_fake_data_bert(GPUS, BS)
+      else:
+        next_data = get_data_bert(GPUS, train_it)
+    except StopIteration:
+      next_data = None
+
+    dt = time.perf_counter()
+
+    device_str = loss.device if isinstance(loss.device, str) else f"{loss.device[0]} * {len(loss.device)}"
+    loss = loss.numpy().item()
+
+    cl = time.perf_counter()
+    if BENCHMARK: step_times.append(cl - st)
+
+    tqdm.write(
+      f"{i:5} {((cl - st)) * 1000.0:7.2f} ms run, {(pt - st) * 1000.0:7.2f} ms python, {(dt - pt) * 1000.0:6.2f} ms fetch data, "
+      f"{(cl - dt) * 1000.0:7.2f} ms {device_str}, {loss:5.2f} loss, {optimizer_wd.lr.numpy()[0]:.6f} LR, "
+      f"{GlobalCounters.mem_used / 1e9:.2f} GB used, {GlobalCounters.global_ops * 1e-9 / (cl - st):9.2f} GFLOPS")
+    if WANDB:
+      wandb.log({"lr": optimizer_wd.lr.numpy(), "train/loss": loss, "train/step_time": cl - st,
+                  "train/python_time": pt - st, "train/data_time": dt - pt, "train/cl_time": cl - dt,
+                  "train/GFLOPS": GlobalCounters.global_ops * 1e-9 / (cl - st)})
+
+    train_data, next_data = next_data, None
+    i += 1
+
+    if i == BENCHMARK:
+      median_step_time = sorted(step_times)[(BENCHMARK + 1) // 2]  # in seconds
+      estimated_total_minutes = int(median_step_time * train_steps / 60)
+      print(f"Estimated training time: {estimated_total_minutes // 60}h{estimated_total_minutes % 60}m")
+      print(f"epoch global_ops: {train_steps * GlobalCounters.global_ops:_}, "
+            f"epoch global_mem: {train_steps * GlobalCounters.global_mem:_}")
+
+    # ** eval loop **
+    if i % eval_step_freq == 0 or (BENCHMARK and i == BENCHMARK):
+      if MLLOGGER and RUNMLPERF:
+        MLLOGGER.start(key=mllog_constants.EVAL_START, value=None, metadata={"epoch_num": 1, "epoch_count": 1, "step_num": i})
+      train_step_bert.reset()
+      eval_lm_losses = []
+      eval_clsf_losses = []
+      eval_lm_accs = []
+      eval_clsf_accs = []
+      eval_times = []
+      Tensor.training = False
+      BEAM.value = EVAL_BEAM
+
+      for j in tqdm(range(max_eval_steps), desc="Evaluating", total=max_eval_steps, disable=BENCHMARK):
+        if INITMLPERF:
+          eval_data = get_fake_data_bert(GPUS, EVAL_BS)
+        else:
+          eval_data = get_data_bert(GPUS, eval_it)
+        GlobalCounters.reset()
+        st = time.time()
+
+        eval_result: dict[str, Tensor] = eval_step_bert(model,
+          eval_data["input_ids"], eval_data["segment_ids"], eval_data["input_mask"], eval_data["masked_lm_positions"],
+          eval_data["masked_lm_ids"], eval_data["masked_lm_weights"], eval_data["next_sentence_labels"])
+
+        lm_loss, clsf_loss  = eval_result["masked_lm_loss"].item(), eval_result["next_sentence_loss"].item()
+        lm_acc, clsf_acc = eval_result["masked_lm_accuracy"].item(), eval_result["next_sentence_accuracy"].item()
+
+        eval_lm_losses.append(lm_loss)
+        eval_clsf_losses.append(clsf_loss)
+        eval_lm_accs.append(lm_acc)
+        eval_clsf_accs.append(clsf_acc)
+
+        et = time.time()
+        eval_times.append(et - st)
+
+        if BENCHMARK and j == BENCHMARK:
+          # assume INITMLPERF has BENCHMARK set
+          if MLLOGGER and INITMLPERF:
+            MLLOGGER.event(key=mllog_constants.INIT_STOP, value=None)
+          return
+
+      eval_step_bert.reset()
+      avg_lm_loss = sum(eval_lm_losses) / len(eval_lm_losses)
+      avg_clsf_loss = sum(eval_clsf_losses) / len(eval_clsf_losses)
+      avg_lm_acc = sum(eval_lm_accs) / len(eval_lm_accs)
+      avg_clsf_acc = sum(eval_clsf_accs) / len(eval_clsf_accs)
+      avg_fw_time = sum(eval_times) / len(eval_times)
+      results = (f"eval lm loss: {avg_lm_loss:.2f}, eval clsf loss: {avg_clsf_loss:.2f}, eval lm accuracy: {avg_lm_acc:.6f}, "
+                 f"eval clsf accuracy: {avg_clsf_acc:.2f}, avg eval step time: {avg_fw_time:.2f}")
+      tqdm.write(results)
+
+      if WANDB:
+        wandb.log({"eval/lm_loss": avg_lm_loss, "eval/clsf_loss": avg_clsf_loss, "eval/lm_accuracy": avg_lm_acc, \
+                    "eval/clsf_accuracy": avg_clsf_acc, "eval/forward_time": avg_fw_time})
+
+      if MLLOGGER and RUNMLPERF:
+        MLLOGGER.end(key=mllog_constants.EVAL_STOP, value=i, metadata={"epoch_count": 1, "step_num": i, "samples_count": config["EVAL_BS"] * config["MAX_EVAL_STEPS"]})
+        MLLOGGER.event(key=mllog_constants.EVAL_ACCURACY, value=avg_lm_acc, metadata={"epoch_num": 1, "masked_lm_accuracy": avg_lm_acc})
+
+      # save model if achieved target
+      if not achieved and avg_lm_acc >= target:
+        wc_end = time.perf_counter()
+        if not os.path.exists(ckpt_dir := save_ckpt_dir): os.mkdir(ckpt_dir)
+        fn = f"{ckpt_dir}/bert-large.safe"
+        safe_save(get_state_dict(model), fn)
+        print(f" *** Model saved to {fn} ***")
+
+        total_seconds = wc_end - wc_start
+        hours = int(total_seconds // 3600)
+        minutes = int((total_seconds % 3600) // 60)
+        seconds = total_seconds % 60
+        print(f"Reference convergence point reached after {i * BS} data samples and {hours}h{minutes}m{seconds:.2f}s.")
+        achieved = True
+        if MLLOGGER and RUNMLPERF:
+          MLLOGGER.end(key=mllog_constants.RUN_STOP, metadata=dict(status=mllog_constants.SUCCESS))
+        # stop once hitting the target
+        break
+
+    if getenv("CKPT", 1) and i % save_ckpt_freq == 0:
+      if MLLOGGER and RUNMLPERF:
+        if previous_step:
+          MLLOGGER.end(key=mllog_constants.BLOCK_STOP, value=None, metadata={"first_epoch_num": 1, "epoch_num": 1, "first_step_num": i, "step_num": i, "step_count": i - previous_step})
+        MLLOGGER.start(key="checkpoint_start", value=None, metadata={"step_num" : i})
+      if not os.path.exists(ckpt_dir := save_ckpt_dir): os.mkdir(ckpt_dir)
+      if WANDB and wandb.run is not None:
+        fn = f"{ckpt_dir}/{time.strftime('%Y%m%d_%H%M%S')}_{wandb.run.id}.safe"
+      else:
+        fn = f"{ckpt_dir}/{time.strftime('%Y%m%d_%H%M%S')}.safe"
+      print(f"saving ckpt to {fn}")
+      safe_save(get_training_state(model, optimizer_group, scheduler_group), fn)
+      ckpt_files = [f for f in os.listdir(ckpt_dir) if os.path.isfile(os.path.join(ckpt_dir, f))]
+      ckpt_files.sort(key=lambda x: os.path.getmtime(os.path.join(ckpt_dir, x)))
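+      # prune the oldest checkpoints (by mtime) so that at most keep_ckpt_amount files remain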
+      while len(ckpt_files) > keep_ckpt_amount:
+        last = ckpt_files.pop(0)
+        print(f"Removing old ckpt {last}")
+        os.remove(os.path.join(ckpt_dir, last))
+      if MLLOGGER and RUNMLPERF:
+        MLLOGGER.end(key="checkpoint_stop", value=None, metadata={"step_num": i})
+        MLLOGGER.start(key=mllog_constants.BLOCK_START, value=None, metadata={"first_epoch_num": 1, "epoch_num": 1, "epoch_count": 1, "samples_count": config["EVAL_BS"] * config["MAX_EVAL_STEPS"], "step_num": i, "first_step_num": i+1})
+        previous_step = i
+
+def train_maskrcnn():
+  # TODO: Mask RCNN
+  pass
+
+if __name__ == "__main__":
+  multiprocessing.set_start_method('spawn')
+  with Tensor.train():
+    for m in getenv("MODEL", "resnet,retinanet,unet3d,rnnt,bert,maskrcnn").split(","):
+      nm = f"train_{m}"
+      if nm in globals():
+        print(f"training {m}")
+        globals()[nm]()

+ 50 - 0
tinychat/examples/mlperf/training_submission_v4.0/tinycorp/benchmarks/resnet/implementations/tinybox_green/README.md

@@ -0,0 +1,50 @@
+# 1. Problem
+
+This problem uses the ResNet-50 CNN to do image classification.
+
+## Requirements
+
+Install tinygrad and mlperf-logging from master.
+```
+git clone https://github.com/tinygrad/tinygrad.git
+python3 -m pip install -e ".[mlperf]"
+```
+
+### tinybox_green
+Install the p2p driver per the [README](https://github.com/tinygrad/open-gpu-kernel-modules/blob/550.54.15-p2p/README.md).
+This is the default on production tinybox green.
+
+### tinybox_red
+Disable CWSR. This is the default on production tinybox red.
+```
+sudo tee /etc/modprobe.d/amdgpu.conf <<EOF
+options amdgpu cwsr_enable=0
+EOF
+sudo update-initramfs -u
+sudo reboot
+
+# validate
+sudo cat /sys/module/amdgpu/parameters/cwsr_enable #= 0
+```
+
+# 2. Directions
+
+## Steps to download and verify data
+
+```
+IMGNET_TRAIN=1 python3 extra/datasets/imagenet_download.py
+```
+
+## Steps for one-time setup
+
+### tinybox_red
+```
+examples/mlperf/training_submission_v4.0/tinycorp/benchmarks/resnet/implementations/tinybox_red/setup.sh
+```
+
+## Steps to run benchmark
+```
+examples/mlperf/training_submission_v4.0/tinycorp/benchmarks/resnet/implementations/tinybox_green/run_and_time.sh
+```
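+
+The `run_and_time.sh` script performs two passes over `model_train.py`: a compile/BEAM warm-up pass with `INITMLPERF=1` and a small `BENCHMARK` step count, then the timed run with `RUNMLPERF=1`, both appending to the same log file. A minimal sketch of the same flow (values here are illustrative, not the tuned defaults):
+```
+BENCHMARK=10 INITMLPERF=1 python3 examples/mlperf/model_train.py | tee run.log
+PARALLEL=0 RUNMLPERF=1 python3 examples/mlperf/model_train.py | tee -a run.log
+```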

+ 13 - 0
tinychat/examples/mlperf/training_submission_v4.0/tinycorp/benchmarks/resnet/implementations/tinybox_green/dev_beam.sh

@@ -0,0 +1,13 @@
+#!/bin/bash
+
+export PYTHONPATH="."
+export MODEL="resnet"
+export DEFAULT_FLOAT="HALF" GPUS=6 BS=1536 EVAL_BS=192
+
+export LAZYCACHE=0 RESET_STEP=0
+
+export TRAIN_BEAM=4 IGNORE_JIT_FIRST_BEAM=1 BEAM_UOPS_MAX=1500 BEAM_UPCAST_MAX=64 BEAM_LOCAL_MAX=1024 BEAM_MIN_PROGRESS=10 BEAM_PADTO=0
+
+export BENCHMARK=10 DEBUG=2
+
+python3 examples/mlperf/model_train.py

+ 15 - 0
tinychat/examples/mlperf/training_submission_v4.0/tinycorp/benchmarks/resnet/implementations/tinybox_green/dev_run.sh

@@ -0,0 +1,15 @@
+#!/bin/bash
+
+export PYTHONPATH="."
+export MODEL="resnet"
+export DEFAULT_FLOAT="HALF" GPUS=6 BS=1536 EVAL_BS=192
+
+export LAZYCACHE=0 RESET_STEP=0
+
+export TRAIN_BEAM=4 IGNORE_JIT_FIRST_BEAM=1 BEAM_UOPS_MAX=1500 BEAM_UPCAST_MAX=64 BEAM_LOCAL_MAX=1024 BEAM_MIN_PROGRESS=10 BEAM_PADTO=0
+
+export EVAL_START_EPOCH=3 EVAL_FREQ=4
+
+export WANDB=1 PARALLEL=0
+
+python3 examples/mlperf/model_train.py

+ 23 - 0
tinychat/examples/mlperf/training_submission_v4.0/tinycorp/benchmarks/resnet/implementations/tinybox_green/run_and_time.sh

@@ -0,0 +1,23 @@
+#!/bin/bash
+
+export PYTHONPATH="."
+export MODEL="resnet"
+export SUBMISSION_PLATFORM="tinybox_green"
+export DEFAULT_FLOAT="HALF" GPUS=6 BS=1536 EVAL_BS=192
+
+export LAZYCACHE=0 RESET_STEP=0
+
+export TRAIN_BEAM=4 IGNORE_JIT_FIRST_BEAM=1 BEAM_UOPS_MAX=1500 BEAM_UPCAST_MAX=64 BEAM_LOCAL_MAX=1024 BEAM_MIN_PROGRESS=10 BEAM_PADTO=0
+
+# pip install -e ".[mlperf]"
+export LOGMLPERF=1
+
+export SEED=$RANDOM
+DATETIME=$(date "+%m%d%H%M")
+LOGFILE="resnet_green_${DATETIME}_${SEED}.log"
+
+# init
+BENCHMARK=10 INITMLPERF=1 python3 examples/mlperf/model_train.py | tee $LOGFILE
+
+# run
+PARALLEL=0 RUNMLPERF=1 EVAL_START_EPOCH=3 EVAL_FREQ=4 python3 examples/mlperf/model_train.py | tee -a $LOGFILE

+ 50 - 0
tinychat/examples/mlperf/training_submission_v4.0/tinycorp/benchmarks/resnet/implementations/tinybox_red/README.md

@@ -0,0 +1,50 @@
+# 1. Problem
+
+This problem uses the ResNet-50 CNN to do image classification.
+
+## Requirements
+
+Install tinygrad and mlperf-logging from master.
+```
+git clone https://github.com/tinygrad/tinygrad.git
+python3 -m pip install -e ".[mlperf]"
+```
+
+### tinybox_green
+Install the p2p driver per the [README](https://github.com/tinygrad/open-gpu-kernel-modules/blob/550.54.15-p2p/README.md).
+This is the default on production tinybox green.
+
+### tinybox_red
+Disable CWSR. This is the default on production tinybox red.
+```
+sudo tee /etc/modprobe.d/amdgpu.conf <<EOF
+options amdgpu cwsr_enable=0
+EOF
+sudo update-initramfs -u
+sudo reboot
+
+# validate
+sudo cat /sys/module/amdgpu/parameters/cwsr_enable #= 0
+```
+
+# 2. Directions
+
+## Steps to download and verify data
+
+```
+IMGNET_TRAIN=1 python3 extra/datasets/imagenet_download.py
+```
+
+## Steps for one-time setup
+
+### tinybox_red
+```
+examples/mlperf/training_submission_v4.0/tinycorp/benchmarks/resnet/implementations/tinybox_red/setup.sh
+```
+
+## Steps to run benchmark
+```
+examples/mlperf/training_submission_v4.0/tinycorp/benchmarks/resnet/implementations/tinybox_red/run_and_time.sh
+```

+ 13 - 0
tinychat/examples/mlperf/training_submission_v4.0/tinycorp/benchmarks/resnet/implementations/tinybox_red/dev_beam.sh

@@ -0,0 +1,13 @@
+#!/bin/bash
+
+export PYTHONPATH="."
+export MODEL="resnet"
+export DEFAULT_FLOAT="HALF" GPUS=6 BS=1536 EVAL_BS=192
+
+export LAZYCACHE=0 RESET_STEP=0
+
+export TRAIN_BEAM=4 IGNORE_JIT_FIRST_BEAM=1 BEAM_UOPS_MAX=2000 BEAM_UPCAST_MAX=96 BEAM_LOCAL_MAX=1024 BEAM_MIN_PROGRESS=5 BEAM_PADTO=0
+
+export BENCHMARK=10 DEBUG=2
+
+python3 examples/mlperf/model_train.py

+ 15 - 0
tinychat/examples/mlperf/training_submission_v4.0/tinycorp/benchmarks/resnet/implementations/tinybox_red/dev_run.sh

@@ -0,0 +1,15 @@
+#!/bin/bash
+
+export PYTHONPATH="."
+export MODEL="resnet"
+export DEFAULT_FLOAT="HALF" GPUS=6 BS=1536 EVAL_BS=192
+
+export LAZYCACHE=0 RESET_STEP=0
+
+export TRAIN_BEAM=4 IGNORE_JIT_FIRST_BEAM=1 BEAM_UOPS_MAX=2000 BEAM_UPCAST_MAX=96 BEAM_LOCAL_MAX=1024 BEAM_MIN_PROGRESS=5 BEAM_PADTO=0
+
+export EVAL_START_EPOCH=3 EVAL_FREQ=4
+
+export WANDB=1 PARALLEL=0
+
+python3 examples/mlperf/model_train.py

+ 23 - 0
tinychat/examples/mlperf/training_submission_v4.0/tinycorp/benchmarks/resnet/implementations/tinybox_red/run_and_time.sh

@@ -0,0 +1,23 @@
+#!/bin/bash
+
+export PYTHONPATH="."
+export MODEL="resnet"
+export SUBMISSION_PLATFORM="tinybox_red"
+export DEFAULT_FLOAT="HALF" GPUS=6 BS=1536 EVAL_BS=192
+
+export LAZYCACHE=0 RESET_STEP=0
+
+export TRAIN_BEAM=4 IGNORE_JIT_FIRST_BEAM=1 BEAM_UOPS_MAX=2000 BEAM_UPCAST_MAX=96 BEAM_LOCAL_MAX=1024 BEAM_MIN_PROGRESS=5 BEAM_PADTO=0
+
+# pip install -e ".[mlperf]"
+export LOGMLPERF=1
+
+export SEED=$RANDOM
+DATETIME=$(date "+%m%d%H%M")
+LOGFILE="resnet_red_${DATETIME}_${SEED}.log"
+
+# init
+BENCHMARK=10 INITMLPERF=1 python3 examples/mlperf/model_train.py | tee $LOGFILE
+
+# run
+PARALLEL=0 RUNMLPERF=1 EVAL_START_EPOCH=3 EVAL_FREQ=4 python3 examples/mlperf/model_train.py | tee -a $LOGFILE

+ 8 - 0
tinychat/examples/mlperf/training_submission_v4.0/tinycorp/benchmarks/resnet/implementations/tinybox_red/setup.sh

@@ -0,0 +1,8 @@
+#!/bin/bash
+
+rocm-smi --setprofile compute
+rocm-smi --setmclk 3
+rocm-smi --setperflevel high
+
+# power cap to 350W
+echo "350000000" | sudo tee /sys/class/drm/card{1..6}/device/hwmon/hwmon*/power1_cap

+ 38 - 0
tinychat/examples/mlperf/training_submission_v4.0/tinycorp/systems/tinybox_green.json

@@ -0,0 +1,38 @@
+{
+  "submitter": "tinycorp",
+  "division": "closed",
+  "system_type": "datacenter",
+  "status": "available",
+  "system_name": "tinybox green",
+  "number_of_nodes": "1",
+  "host_processors_per_node": "1",
+  "host_processor_model_name": "AMD EPYC 7532 32-Core Processor",
+  "host_processor_core_count": "64",
+  "host_processor_frequency": "",
+  "host_processor_caches": "",
+  "host_processor_interconnect": "",
+  "host_memory_capacity": "128GB",
+  "host_storage_type": "NVMe SSD",
+  "host_storage_capacity": "4 TB raid array + 1 TB boot",
+  "host_networking": "",
+  "host_networking_topology": "",
+  "host_memory_configuration": "8x 16GB DDR4",
+  "accelerators_per_node": "6",
+  "accelerator_model_name": "NVIDIA GeForce RTX 4090",
+  "accelerator_host_interconnect": "PCIe 4.0 x16",
+  "accelerator_frequency": "",
+  "accelerator_on-chip_memories": "",
+  "accelerator_memory_configuration": "GDDR6X",
+  "accelerator_memory_capacity": "24GB",
+  "accelerator_interconnect": "",
+  "accelerator_interconnect_topology": "",
+  "cooling": "air",
+  "hw_notes": "",
+  "framework": "tinygrad, commit 0e8aa0e2886bf9a2d3ce093bce87305e182e6d4a",
+  "other_software_stack": {
+    "python": "3.10.12",
+    "CUDA": "12.4"
+  },
+  "operating_system": "Ubuntu 22.04.4",
+  "sw_notes": ""
+}

+ 38 - 0
tinychat/examples/mlperf/training_submission_v4.0/tinycorp/systems/tinybox_red.json

@@ -0,0 +1,38 @@
+{
+  "submitter": "tinycorp",
+  "division": "closed",
+  "system_type": "datacenter",
+  "status": "available",
+  "system_name": "tinybox red",
+  "number_of_nodes": "1",
+  "host_processors_per_node": "1",
+  "host_processor_model_name": "AMD EPYC 7532 32-Core Processor",
+  "host_processor_core_count": "64",
+  "host_processor_frequency": "",
+  "host_processor_caches": "",
+  "host_processor_interconnect": "",
+  "host_memory_capacity": "128GB",
+  "host_storage_type": "NVMe SSD",
+  "host_storage_capacity": "4 TB raid array + 1 TB boot",
+  "host_networking": "",
+  "host_networking_topology": "",
+  "host_memory_configuration": "8x 16GB DDR4",
+  "accelerators_per_node": "6",
+  "accelerator_model_name": "AMD Radeon RX 7900 XTX",
+  "accelerator_host_interconnect": "PCIe 4.0 x16",
+  "accelerator_frequency": "",
+  "accelerator_on-chip_memories": "",
+  "accelerator_memory_configuration": "GDDR6",
+  "accelerator_memory_capacity": "24GB",
+  "accelerator_interconnect": "",
+  "accelerator_interconnect_topology": "",
+  "cooling": "air",
+  "hw_notes": "",
+  "framework": "tinygrad, commit 0e8aa0e2886bf9a2d3ce093bce87305e182e6d4a",
+  "other_software_stack": {
+    "python": "3.10.12",
+    "ROCm": "6.1"
+  },
+  "operating_system": "Ubuntu 22.04.4",
+  "sw_notes": ""
+}

+ 107 - 0
tinychat/examples/mnist_gan.py

@@ -0,0 +1,107 @@
+from pathlib import Path
+import numpy as np
+import torch
+from torchvision.utils import make_grid, save_image
+from tinygrad.nn.state import get_parameters
+from tinygrad.tensor import Tensor
+from tinygrad.helpers import trange
+from tinygrad.nn import optim
+from extra.datasets import fetch_mnist
+
+class LinearGen:
+  def __init__(self):
+    self.l1 = Tensor.scaled_uniform(128, 256)
+    self.l2 = Tensor.scaled_uniform(256, 512)
+    self.l3 = Tensor.scaled_uniform(512, 1024)
+    self.l4 = Tensor.scaled_uniform(1024, 784)
+
+  def forward(self, x):
+    x = x.dot(self.l1).leakyrelu(0.2)
+    x = x.dot(self.l2).leakyrelu(0.2)
+    x = x.dot(self.l3).leakyrelu(0.2)
+    x = x.dot(self.l4).tanh()
+    return x
+
+class LinearDisc:
+  def __init__(self):
+    self.l1 = Tensor.scaled_uniform(784, 1024)
+    self.l2 = Tensor.scaled_uniform(1024, 512)
+    self.l3 = Tensor.scaled_uniform(512, 256)
+    self.l4 = Tensor.scaled_uniform(256, 2)
+
+  def forward(self, x):
+    # balance the discriminator inputs with const bias (.add(1))
+    x = x.dot(self.l1).add(1).leakyrelu(0.2).dropout(0.3)
+    x = x.dot(self.l2).leakyrelu(0.2).dropout(0.3)
+    x = x.dot(self.l3).leakyrelu(0.2).dropout(0.3)
+    x = x.dot(self.l4).log_softmax()
+    return x
+
+def make_batch(images):
+  sample = np.random.randint(0, len(images), size=(batch_size))
+  image_b = images[sample].reshape(-1, 28*28).astype(np.float32) / 127.5 - 1.0
+  return Tensor(image_b)
+
+def make_labels(bs, col, val=-2.0):
+  y = np.zeros((bs, 2), np.float32)
+  y[range(bs), [col] * bs] = val  # Can we do label smoothing? i.e -2.0 changed to -1.98789.
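+  # with log_softmax discriminator outputs, a value of -2.0 at the target column makes (output * labels).mean() a scaled negative log-likelihood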
+  return Tensor(y)
+
+def train_discriminator(optimizer, data_real, data_fake):
+  real_labels = make_labels(batch_size, 1)
+  fake_labels = make_labels(batch_size, 0)
+  optimizer.zero_grad()
+  output_real = discriminator.forward(data_real)
+  output_fake = discriminator.forward(data_fake)
+  loss_real = (output_real * real_labels).mean()
+  loss_fake = (output_fake * fake_labels).mean()
+  loss_real.backward()
+  loss_fake.backward()
+  optimizer.step()
+  return (loss_real + loss_fake).numpy()
+
+def train_generator(optimizer, data_fake):
+  real_labels = make_labels(batch_size, 1)
+  optimizer.zero_grad()
+  output = discriminator.forward(data_fake)
+  loss = (output * real_labels).mean()
+  loss.backward()
+  optimizer.step()
+  return loss.numpy()
+
+if __name__ == "__main__":
+  # data for training and validation
+  images_real = np.vstack(fetch_mnist()[::2])
+  ds_noise = Tensor.randn(64, 128, requires_grad=False)
+  # parameters
+  epochs, batch_size, k = 300, 512, 1
+  sample_interval = epochs // 10
+  n_steps = len(images_real) // batch_size
+  # models and optimizer
+  generator = LinearGen()
+  discriminator = LinearDisc()
+  # path to store results
+  output_dir = Path(".").resolve() / "outputs"
+  output_dir.mkdir(exist_ok=True)
+  # optimizers
+  optim_g = optim.Adam(get_parameters(generator),lr=0.0002, b1=0.5)  # 0.0002 for equilibrium!
+  optim_d = optim.Adam(get_parameters(discriminator),lr=0.0002, b1=0.5)
+  # training loop
+  Tensor.training = True
+  for epoch in (t := trange(epochs)):
+    loss_g, loss_d = 0.0, 0.0
+    for _ in range(n_steps):
+      data_real = make_batch(images_real)
+      for step in range(k):  # Try with k = 5 or 7.
+        noise = Tensor.randn(batch_size, 128)
+        data_fake = generator.forward(noise).detach()
+        loss_d += train_discriminator(optim_d, data_real, data_fake)
+      noise = Tensor.randn(batch_size, 128)
+      data_fake = generator.forward(noise)
+      loss_g += train_generator(optim_g, data_fake)
+    if (epoch + 1) % sample_interval == 0:
+      fake_images = generator.forward(ds_noise).detach().numpy()
+      fake_images = (fake_images.reshape(-1, 1, 28, 28) + 1) / 2  # 0 - 1 range.
+      save_image(make_grid(torch.tensor(fake_images)), output_dir / f"image_{epoch+1}.jpg")
+    t.set_description(f"Generator loss: {loss_g/n_steps}, Discriminator loss: {loss_d/n_steps}")
+  print("Training Completed!")

+ 118 - 0
tinychat/examples/openelm.py

@@ -0,0 +1,118 @@
+import json, pprint
+from tinygrad import fetch, nn, Tensor
+from tinygrad.helpers import DEBUG
+
+class FeedForward:
+  def __init__(self, model_dim, intermediate_dim):
+    self.proj_1 = nn.Linear(model_dim, 2*intermediate_dim, bias=False)
+    self.proj_2 = nn.Linear(intermediate_dim, model_dim, bias=False)
+
+  def __call__(self, x):
+    y_12 = self.proj_1(x)
+    y_1, y_2 = y_12.chunk(2, dim=-1)
+    return self.proj_2(y_1.silu() * y_2)
+
+# NOTE: this RoPE doesn't match LLaMA's?
+def _rotate_half(x: Tensor) -> Tensor:
+  x1, x2 = x.chunk(2, dim=-1)
+  return Tensor.cat(-x2, x1, dim=-1)
+
+def _apply_rotary_pos_emb(x: Tensor, pos_sin: Tensor, pos_cos: Tensor) -> Tensor:
+  return (x * pos_cos) + (_rotate_half(x) * pos_sin)
+
+class Attention:
+  def __init__(self, model_dim, num_query_heads, num_kv_heads, head_dim):
+    self.qkv_proj = nn.Linear(model_dim, (num_query_heads + num_kv_heads*2) * head_dim, bias=False)
+    self.num_query_heads, self.num_kv_heads = num_query_heads, num_kv_heads
+    self.head_dim = head_dim
+    self.q_norm = nn.RMSNorm(head_dim)
+    self.k_norm = nn.RMSNorm(head_dim)
+    self.out_proj = nn.Linear(num_query_heads * head_dim, model_dim, bias=False)
+
+  def __call__(self, x:Tensor) -> Tensor:
+    batch_size, seq_len, embed_dim = x.shape
+    qkv = self.qkv_proj(x)
+    qkv = qkv.reshape(batch_size, seq_len, self.num_query_heads+self.num_kv_heads*2, self.head_dim).transpose(1, 2)
+    xq,xk,xv = qkv.split([self.num_query_heads, self.num_kv_heads, self.num_kv_heads], dim=1)
+    xq = self.q_norm(xq)
+    xk = self.k_norm(xk)
+
+    # add positional embedding (how many kernels is this?)
+    freq_constant = 10000
+    inv_freq = 1.0 / (freq_constant ** (Tensor.arange(0, self.head_dim, 2) / self.head_dim))
+    pos_index_theta = Tensor.einsum("i,j->ij", Tensor.arange(seq_len), inv_freq)
+    emb = Tensor.cat(pos_index_theta, pos_index_theta, dim=-1)
+    cos_emb, sin_emb = emb.cos()[None, None, :, :], emb.sin()[None, None, :, :]
+    xq = _apply_rotary_pos_emb(xq, sin_emb, cos_emb)
+    xk = _apply_rotary_pos_emb(xk, sin_emb, cos_emb)
+
+    # grouped-query attention
+    num_groups = self.num_query_heads // self.num_kv_heads
+    xk = xk.repeat_interleave(num_groups, dim=1)
+    xv = xv.repeat_interleave(num_groups, dim=1)
+
+    # masked attention
+    #start_pos = 0
+    #mask = Tensor.full((1, 1, seq_len, start_pos+seq_len), float("-inf"), dtype=xq.dtype, device=xq.device).triu(start_pos+1)
+    #attn_output = xq.scaled_dot_product_attention(xk, xv, mask).transpose(1, 2)
+
+    # causal is fine, no mask needed
+    attn_output = xq.scaled_dot_product_attention(xk, xv, is_causal=True).transpose(1, 2)
+    return self.out_proj(attn_output.reshape(batch_size, seq_len, self.num_query_heads * self.head_dim))
+
+class Layer:
+  def __init__(self, model_dim, intermediate_dim, num_query_heads, num_kv_heads, head_dim):
+    self.ffn = FeedForward(model_dim, intermediate_dim)
+    self.attn = Attention(model_dim, num_query_heads, num_kv_heads, head_dim)
+    self.ffn_norm = nn.RMSNorm(model_dim)
+    self.attn_norm = nn.RMSNorm(model_dim)
+
+  def __call__(self, x:Tensor) -> Tensor: # (batch, seq_len, embed_dim)
+    x = x + self.attn(self.attn_norm(x))
+    x = x + self.ffn(self.ffn_norm(x))
+    return x
+
+# stupidly complex
+def make_divisible(v, divisor):
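+  # rounds v to the nearest multiple of divisor, bumping up one more divisor if that would fall below 90% of v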
+  new_v = max(divisor, int(v + divisor / 2) // divisor * divisor)
+  if new_v < 0.9 * v: new_v += divisor
+  return new_v
+
+class Transformer:
+  def __init__(self, cfg):
+    if DEBUG >= 3: pprint.pp(cfg)
+    self.layers = [Layer(cfg['model_dim'], make_divisible(int(cfg["model_dim"] * cfg['ffn_multipliers'][i]), cfg['ffn_dim_divisor']),
+                         cfg['num_query_heads'][i], cfg['num_kv_heads'][i], cfg['head_dim']) for i in range(cfg['num_transformer_layers'])]
+    self.norm = nn.RMSNorm(cfg['model_dim'])
+    self.token_embeddings = nn.Embedding(cfg['vocab_size'], cfg['model_dim'])
+
+  def __call__(self, tokens:Tensor):
+    # _bsz, seqlen = tokens.shape
+    x = self.token_embeddings(tokens)
+    for l in self.layers: x = l(x)
+    return self.norm(x) @ self.token_embeddings.weight.T
+
+if __name__ == "__main__":
+  #model_name = "OpenELM-270M-Instruct"
+  model_name = "OpenELM-270M"  # this is fp32
+  model = Transformer(json.loads(fetch(f"https://huggingface.co/apple/{model_name}/resolve/main/config.json?download=true").read_bytes()))
+  weights = nn.state.safe_load(fetch(f"https://huggingface.co/apple/{model_name}/resolve/main/model.safetensors?download=true"))
+  if DEBUG >= 3:
+    for k, v in weights.items(): print(k, v.shape)
+  nn.state.load_state_dict(model, {k.removeprefix("transformer."):v for k,v in weights.items()})
+
+  from sentencepiece import SentencePieceProcessor
+  tokenizer = SentencePieceProcessor(fetch("https://github.com/karpathy/llama2.c/raw/master/tokenizer.model").as_posix())
+  toks = [tokenizer.bos_id()] + tokenizer.encode("Some car brands include")
+  for i in range(100):
+    ttoks = Tensor([toks])
+    out = model(ttoks).realize()
+    t0 = out[0].argmax(axis=-1).tolist()
+    toks.append(t0[-1])
+    # hmmm...passthrough still doesn't match (it shouldn't, it outputs the most likely)
+    print(tokenizer.decode(toks))
+    #print(toks)
+    #print(tokenizer.decode(t0))
+    #print(t0)
+
+

+ 204 - 0
tinychat/examples/openpilot/compile2.py

@@ -0,0 +1,204 @@
+#!/usr/bin/env python3
+import os, sys, io, pathlib, json, struct
+import numpy as np
+sys.path.insert(0, str(pathlib.Path(__file__).parents[1]))
+
+if "FLOAT16" not in os.environ: os.environ["FLOAT16"] = "1"
+if "IMAGE" not in os.environ: os.environ["IMAGE"] = "2"
+if "NOLOCALS" not in os.environ: os.environ["NOLOCALS"] = "1"
+
+OPENPILOT_MODEL = "https://github.com/commaai/openpilot/raw/v0.9.4/selfdrive/modeld/models/supercombo.onnx"
+
+import onnx
+from typing import Tuple, List, Optional, Dict, cast
+from extra.onnx import get_run_onnx
+from tinygrad import Tensor, Device, GlobalCounters, dtypes
+from tinygrad.dtype import ImageDType
+from tinygrad.device import Buffer
+from tinygrad.helpers import partition, Context, fetch, getenv, DEBUG, tqdm
+from tinygrad.engine.realize import run_schedule, lower_schedule, ExecItem, CompiledRunner
+from tinygrad.engine.schedule import ScheduleItem, create_schedule, memory_planner
+from tinygrad.ops import MetaOps
+from tinygrad.tensor import _to_np_dtype
+Device.DEFAULT = "GPU"
+
+def get_schedule(onnx_data) -> Tuple[List[ScheduleItem], List[ScheduleItem], Dict[str, Tensor]]:
+  Tensor.no_grad = True
+  Tensor.training = False
+
+  # load the model
+  onnx_model = onnx.load(io.BytesIO(onnx_data))
+  run_onnx = get_run_onnx(onnx_model)
+  input_shapes = {inp.name:tuple(x.dim_value for x in inp.type.tensor_type.shape.dim) for inp in onnx_model.graph.input}
+
+  # run the model
+  inputs = {k:Tensor.empty(*shp) for k,shp in input_shapes.items()}
+  ret: Tensor = next(iter(run_onnx(inputs).values())).cast(dtypes.float32).contiguous()
+  schedule = create_schedule([ret.lazydata])
+
+  # filter schedule that don't depend on the inputs
+  input_lb = [x.lazydata.base.buffer for x in inputs.values()]
+  depends = set(input_lb)
+  for si in schedule:
+    if any(b in depends for b in si.inputs):
+      for out in si.outputs: depends.add(out)
+
+  # run all kernels that don't depend on the inputs
+  # NOTE: there's two extra kernels due to fusions that now happen since the weights aren't realized
+  schedule, schedule_independent = partition(schedule, lambda si: any(out in depends for out in si.outputs))
+  print(f"{len(schedule)} schedule items depend on the input, {len(schedule_independent)} don't")
+
+  # confirm no non-sink metaop in the (non independent) schedule except for the ones that load the input buffers
+  assert all(si.ast.op is MetaOps.KERNEL or out in input_lb for si in schedule for out in si.outputs), "has non SINK ops, can't compile to Thneed"
+  return schedule, schedule_independent, inputs
+
+def test_vs_onnx(onnx_data, eis:Optional[List[ExecItem]], inputs:Dict[str, Tensor]):
+  import onnx
+  #import pyopencl as cl
+  #from extra.thneed import Thneed
+  import numpy as np
+  onnx_model = onnx.load(io.BytesIO(onnx_data))
+
+  input_shapes = {inp.name:tuple(x.dim_value for x in inp.type.tensor_type.shape.dim) for inp in onnx_model.graph.input}
+  Tensor.manual_seed(1337)
+  new_inputs = {k:Tensor.randn(*shp, requires_grad=False)*8 for k,shp in input_shapes.items()}
+  new_np_inputs = {k:v.realize().numpy() for k,v in new_inputs.items()}
+
+  if getenv("ORT"):
+    # test with onnxruntime
+    import onnxruntime as ort
+    onnx_session = ort.InferenceSession(onnx_data)
+    onnx_output = onnx_session.run([onnx_model.graph.output[0].name], {k:v.astype(np.float16) for k,v in new_np_inputs.items()})
+    new_torch_out = onnx_output[0]
+    print("got ort outputs")
+  else:
+    # test with torch
+    from test.models.test_onnx import run_onnx_torch
+    new_torch_out = run_onnx_torch(onnx_model, new_np_inputs).numpy()
+    print("got torch outputs")
+
+  # if you don't have a schedule
+  if eis is None:
+    run_onnx = get_run_onnx(onnx_model)
+    new_tinygrad_out = next(iter(run_onnx(new_inputs).values())).cast(dtypes.float32).numpy()
+    np.testing.assert_allclose(new_torch_out, new_tinygrad_out, atol=1e-4, rtol=1e-2)
+    print("classic self-test passed!")
+    return
+
+  # set inputs
+  for k,v in inputs.items(): v.lazydata.base.realized.copyin(new_np_inputs[k].data)
+
+  # run code (all buffers have been allocated)
+  GlobalCounters.reset()
+  output = eis[-1].bufs[0]
+  for ei in eis: ei.run()
+
+  new_tinygrad_out = np.frombuffer(output.as_buffer(), dtype=_to_np_dtype(output.dtype))
+  np.testing.assert_allclose(new_torch_out.reshape(new_tinygrad_out.shape), new_tinygrad_out, atol=1e-4, rtol=1e-2)
+  print("semi-thneed self-test passed!")
+
+if __name__ == "__main__":
+  onnx_data = fetch(sys.argv[1] if len(sys.argv) > 1 else OPENPILOT_MODEL).read_bytes()
+
+  # quick test for ONNX issues
+  #thneed_test_onnx(onnx_data, None)
+  #exit(0)
+
+  schedule, schedule_independent, inputs = get_schedule(onnx_data)
+  schedule, schedule_input = partition(schedule, lambda x: x.ast.op is MetaOps.KERNEL)
+  print(f"{len(schedule_input)} inputs")
+
+  run_schedule(schedule_independent)
+  run_schedule(schedule_input)
+  with Context(DEBUG=max(DEBUG.value, 2), BEAM=getenv("LATEBEAM")):
+    schedule = memory_planner(schedule)
+    for si in schedule:
+      for b in si.outputs:
+        assert not b.is_allocated(), "output should not be allocated"
+    image_count = sum(isinstance(out.dtype, ImageDType) for si in schedule for out in si.outputs)
+    print(f"**** compiling real kernels {image_count}/{len(schedule)} images ****")
+    eis = list(tqdm(lower_schedule(schedule), total=len(schedule)))
+
+  print("kernel count:", len(eis))
+  assert len(eis) <= getenv("ALLOWED_KERNEL_COUNT", 0) or getenv("ALLOWED_KERNEL_COUNT", 0) == 0, "too many kernels!"
+
+  # new simple thneed
+  def to_ref(b:Buffer): return struct.pack("Q", id(b)).decode("latin_1")
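+  # buffers are referenced in the thneed JSON by their python id(), packed as 8 raw bytes and decoded with latin-1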
+
+  seen_buffers = set()
+  input_buffers = [x.lazydata.buffer for x in inputs.values()]
+  jdat = {"binaries": [], "programs": {}, "kernels": [], "objects": []}
+  jdat["inputs"] = {k:to_ref(v.lazydata.buffer) for k,v in inputs.items()}
+  jdat["outputs"] = [to_ref(eis[-1].bufs[0])]
+  weights = []
+  for i,ei in enumerate(eis):
+    #print("***", i)
+    for b in ei.bufs:
+      needs_load = b.is_allocated() and b not in input_buffers
+      #print(b, needs_load)
+      if b in seen_buffers: continue
+      seen_buffers.add(b)
+      if isinstance(b.dtype, ImageDType):
+        base_dtype = dtypes.float16 if b.dtype.fmt == 'e' else dtypes.float32
+        row_pitch = (b.dtype.shape[0]*4*base_dtype.itemsize + 63)//64 * 64
+        size = row_pitch * b.dtype.shape[1]
+        jdat['objects'].append({
+          "id": to_ref(b), "needs_load": needs_load, "size": size, "arg_type": "image2d_t",
+          "width": b.dtype.shape[0], "height": b.dtype.shape[1], "row_pitch": row_pitch, "float32": b.dtype.base == dtypes.float32,
+        })
+        if needs_load:
+          t = Tensor.empty(b.dtype.shape, dtype=b.dtype)
+          t.lazydata.buffer = b
+          data = t.cast(dtypes.float32).pad(((0, row_pitch//(4*base_dtype.itemsize)-b.dtype.shape[0]), (0,0), (0,0))).contiguous().numpy()
+          # NOTE: this cast must be done in numpy for platforms that don't support half
+          if base_dtype == dtypes.float16: data = data.astype(np.float16)
+          weights.append(data.tobytes())
+          assert len(weights[-1]) == size, "wrong size buffer"
+      else:
+        jdat['objects'].append({
+          "id": to_ref(b), "arg_type": b.dtype.name + "*", "needs_load": needs_load, "size": b.nbytes,
+        })
+        if needs_load:
+          weights.append(b.as_buffer())
+          assert len(weights[-1]) == b.nbytes, "wrong size buffer"
+
+  saved_binaries = set()
+  binaries = []
+  GlobalCounters.reset()
+  with Context(DEBUG=max(DEBUG.value, 2)):
+    for ei in eis:
+      prg = cast(CompiledRunner, ei.prg)
+      assert len(prg.p.vars) == 0
+      if prg.p.function_name not in saved_binaries:
+        jdat['binaries'].append({"name":prg.p.function_name, "length":len(prg.lib)})
+        binaries.append(prg.lib)
+        saved_binaries.add(prg.p.function_name)
+      ei.run()
+      jdat['kernels'].append({
+        "name": prg.p.function_name,
+        "work_dim": len(prg.p.global_size),
+        "global_work_size": prg.p.global_size,
+        "local_work_size": prg.p.local_size,
+        "num_args": len(ei.bufs),
+        "args": [to_ref(b) for b in ei.bufs],
+        "arg_size": [8]*len(ei.bufs),
+      })
+
+  output_fn = sys.argv[2] if len(sys.argv) >= 3 else "/tmp/output.thneed"
+  print(f"saving thneed to {output_fn} with {len(weights)} buffers and {len(binaries)} binaries")
+  with open(output_fn, "wb") as f:
+    j = json.dumps(jdat, ensure_ascii=False).encode('latin_1')
+    f.write(struct.pack("I", len(j)))
+    f.write(j)
+    for w in weights: f.write(w)
+    for b in binaries: f.write(b)
+    print("saved", f.tell(), "bytes")
+
+  FLOAT16 = getenv("FLOAT16", 0)
+  if FLOAT16 == 0:
+    try:
+      test_vs_onnx(onnx_data, eis, inputs)
+    except ModuleNotFoundError as e:
+      print(f"TEST NOT HAPPENING {e}")
+
+

+ 2 - 0
tinychat/examples/openpilot/go.sh

@@ -0,0 +1,2 @@
+#!/bin/bash
+NOLOCALS=1 FLOAT16=1 DEBUGCL=1 IMAGE=2 GPU=1 python3 examples/openpilot/compile2.py

+ 55 - 0
tinychat/examples/other_mnist/beautiful_mnist_mlx.py

@@ -0,0 +1,55 @@
+from tinygrad.helpers import trange
+from tinygrad.nn.datasets import mnist
+import mlx.core as mx
+import mlx.nn as nn
+import mlx.optimizers as optim
+from functools import partial
+
+class Model(nn.Module):
+  def __init__(self):
+    super().__init__()
+    self.c1 = nn.Conv2d(1, 32, 5)
+    self.c2 = nn.Conv2d(32, 32, 5)
+    self.bn1 = nn.BatchNorm(32)
+    self.m1 = nn.MaxPool2d(2)
+    self.c3 = nn.Conv2d(32, 64, 3)
+    self.c4 = nn.Conv2d(64, 64, 3)
+    self.bn2 = nn.BatchNorm(64)
+    self.m2 = nn.MaxPool2d(2)
+    self.lin = nn.Linear(576, 10)
+  def __call__(self, x):
+    x = mx.maximum(self.c1(x), 0)
+    x = mx.maximum(self.c2(x), 0)
+    x = self.m1(self.bn1(x))
+    x = mx.maximum(self.c3(x), 0)
+    x = mx.maximum(self.c4(x), 0)
+    x = self.m2(self.bn2(x))
+    return self.lin(mx.flatten(x, 1))
+
+if __name__ == "__main__":
+  X_train, Y_train, X_test, Y_test = mnist()
+  X_train = mx.array(X_train.float().permute((0,2,3,1)).numpy())
+  Y_train = mx.array(Y_train.numpy())
+  X_test = mx.array(X_test.float().permute((0,2,3,1)).numpy())
+  Y_test = mx.array(Y_test.numpy())
+
+  model = Model()
+  optimizer = optim.Adam(1e-3)
+  def loss_fn(model, x, y): return nn.losses.cross_entropy(model(x), y).mean()
+
+  state = [model.state, optimizer.state]
+  @partial(mx.compile, inputs=state, outputs=state)
+  def step(samples):
+    # Compiled functions will also treat any inputs not in the parameter list as constants.
+    X,Y = X_train[samples], Y_train[samples]
+    loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
+    loss, grads = loss_and_grad_fn(model, X, Y)
+    optimizer.update(model, grads)
+    return loss
+
+  test_acc = float('nan')
+  for i in (t:=trange(70)):
+    samples = mx.random.randint(0, X_train.shape[0], (512,))  # putting this in JIT didn't work well
+    loss = step(samples)
+    if i%10 == 9: test_acc = ((model(X_test).argmax(axis=-1) == Y_test).sum() * 100 / X_test.shape[0]).item()
+    t.set_description(f"loss: {loss.item():6.2f} test_accuracy: {test_acc:5.2f}%")

+ 55 - 0
tinychat/examples/other_mnist/beautiful_mnist_torch.py

@@ -0,0 +1,55 @@
+from tinygrad import dtypes
+from tinygrad.helpers import trange
+from tinygrad.nn.datasets import mnist
+import torch
+from torch import nn, optim
+
+class Model(nn.Module):
+  def __init__(self):
+    super().__init__()
+    self.c1 = nn.Conv2d(1, 32, 5)
+    self.c2 = nn.Conv2d(32, 32, 5)
+    self.bn1 = nn.BatchNorm2d(32)
+    self.m1 = nn.MaxPool2d(2)
+    self.c3 = nn.Conv2d(32, 64, 3)
+    self.c4 = nn.Conv2d(64, 64, 3)
+    self.bn2 = nn.BatchNorm2d(64)
+    self.m2 = nn.MaxPool2d(2)
+    self.lin = nn.Linear(576, 10)
+  def forward(self, x):
+    x = nn.functional.relu(self.c1(x))
+    x = nn.functional.relu(self.c2(x))
+    x = self.m1(self.bn1(x))
+    x = nn.functional.relu(self.c3(x))
+    x = nn.functional.relu(self.c4(x))
+    x = self.m2(self.bn2(x))
+    return self.lin(torch.flatten(x, 1))
+
+if __name__ == "__main__":
+  mps_device = torch.device("mps")
+  X_train, Y_train, X_test, Y_test = mnist()
+  X_train = torch.tensor(X_train.float().numpy(), device=mps_device)
+  Y_train = torch.tensor(Y_train.cast(dtypes.int64).numpy(), device=mps_device)
+  X_test = torch.tensor(X_test.float().numpy(), device=mps_device)
+  Y_test = torch.tensor(Y_test.cast(dtypes.int64).numpy(), device=mps_device)
+
+  model = Model().to(mps_device)
+  optimizer = optim.Adam(model.parameters(), 1e-3)
+
+  loss_fn = nn.CrossEntropyLoss()
+  #@torch.compile
+  def step(samples):
+    X,Y = X_train[samples], Y_train[samples]
+    out = model(X)
+    loss = loss_fn(out, Y)
+    optimizer.zero_grad()
+    loss.backward()
+    optimizer.step()
+    return loss
+
+  test_acc = float('nan')
+  for i in (t:=trange(70)):
+    samples = torch.randint(0, X_train.shape[0], (512,))  # putting this in JIT didn't work well
+    loss = step(samples)
+    if i%10 == 9: test_acc = ((model(X_test).argmax(axis=-1) == Y_test).sum() * 100 / X_test.shape[0]).item()
+    t.set_description(f"loss: {loss.item():6.2f} test_accuracy: {test_acc:5.2f}%")

+ 45 - 0
tinychat/examples/rl/lightupbutton.py

@@ -0,0 +1,45 @@
+import gymnasium as gym
+import numpy as np
+from gymnasium.envs.registration import register
+
+# a very simple game
+# one of <size> lights will light up
+# take the action of the lit up light
+# in <hard_mode>, you act differently based on the step number and need to track this
+
+class PressTheLightUpButton(gym.Env):
+  metadata = {"render_modes": []}
+  def __init__(self, render_mode=None, size=2, game_length=10, hard_mode=False):
+    self.size, self.game_length = size, game_length
+    self.observation_space = gym.spaces.Box(0, 1, shape=(self.size,), dtype=np.float32)
+    self.action_space = gym.spaces.Discrete(self.size)
+    self.step_num = 0
+    self.done = True
+    self.hard_mode = hard_mode
+
+  def _get_obs(self):
+    obs = [0]*self.size
+    if self.step_num < len(self.state):
+      obs[self.state[self.step_num]] = 1
+    return np.array(obs, dtype=np.float32)
+
+  def reset(self, seed=None, options=None):
+    super().reset(seed=seed)
+    self.state = np.random.randint(0, self.size, size=self.game_length)
+    self.step_num = 0
+    self.done = False
+    return self._get_obs(), {}
+
+  def step(self, action):
+    target = ((action + self.step_num) % self.size) if self.hard_mode else action
+    reward = int(target == self.state[self.step_num])
+    self.step_num += 1
+    if not reward:
+      self.done = True
+    return self._get_obs(), reward, self.done, self.step_num >= self.game_length, {}
+
+register(
+  id="PressTheLightUpButton-v0",
+  entry_point="examples.rl.lightupbutton:PressTheLightUpButton",
+  max_episode_steps=None,
+)
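+
+# hypothetical usage sketch (not part of the original example): importing this module registers the id,
+# after which the env follows the standard gymnasium API:
+#   env = gym.make("PressTheLightUpButton-v0", size=4, game_length=10)
+#   obs, _ = env.reset()
+#   obs, reward, terminated, truncated, _ = env.step(int(obs.argmax()))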

+ 147 - 0
tinychat/examples/sdv2.py

@@ -0,0 +1,147 @@
+from tinygrad import Tensor, dtypes, TinyJit
+from tinygrad.helpers import fetch
+from tinygrad.nn.state import safe_load, load_state_dict, get_state_dict
+from examples.stable_diffusion import AutoencoderKL, get_alphas_cumprod
+from examples.sdxl import DPMPP2MSampler, append_dims, LegacyDDPMDiscretization
+from extra.models.unet import UNetModel
+from extra.models.clip import FrozenOpenClipEmbedder
+
+from typing import Dict
+import argparse, tempfile, os
+from pathlib import Path
+from PIL import Image
+
+class DiffusionModel:
+  def __init__(self, model:UNetModel):
+    self.diffusion_model = model
+
+@TinyJit
+def run(model, x, tms, ctx, c_out, add):
+  return (model(x, tms, ctx)*c_out + add).realize()
+
+# https://github.com/Stability-AI/stablediffusion/blob/cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf/ldm/models/diffusion/ddpm.py#L521
+class StableDiffusionV2:
+  def __init__(self, unet_config:Dict, cond_stage_config:Dict, parameterization:str="v"):
+    self.model             = DiffusionModel(UNetModel(**unet_config))
+    self.first_stage_model = AutoencoderKL()
+    self.cond_stage_model  = FrozenOpenClipEmbedder(**cond_stage_config)
+    self.alphas_cumprod    = get_alphas_cumprod()
+    self.parameterization  = parameterization
+
+    self.discretization = LegacyDDPMDiscretization()
+    self.sigmas = self.discretization(1000, flip=True)
+
+  def denoise(self, x:Tensor, sigma:Tensor, cond:Dict) -> Tensor:
+
+    def sigma_to_idx(s:Tensor) -> Tensor:
+      dists = s - self.sigmas.unsqueeze(1)
+      return dists.abs().argmin(axis=0).view(*s.shape)
+
+    sigma = self.sigmas[sigma_to_idx(sigma)]
+    sigma_shape = sigma.shape
+    sigma = append_dims(sigma, x)
+
+    c_skip = 1.0 / (sigma**2 + 1.0)
+    c_out = -sigma / (sigma**2 + 1.0) ** 0.5
+    c_in = 1.0 / (sigma**2 + 1.0) ** 0.5
+    c_noise = sigma_to_idx(sigma.reshape(sigma_shape))
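+    # these scalings (c_skip, c_out, c_in) appear to match the v-prediction parameterization selected above (parameterization="v")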
+
+    def prep(*tensors:Tensor):
+      return tuple(t.cast(dtypes.float16).realize() for t in tensors)
+
+    return run(self.model.diffusion_model, *prep(x*c_in, c_noise, cond["crossattn"], c_out, x*c_skip))
+
+  def decode(self, x:Tensor, height:int, width:int) -> Tensor:
+    x = self.first_stage_model.post_quant_conv(1/0.18215 * x)
+    x = self.first_stage_model.decoder(x)
+
+    # make image correct size and scale
+    x = (x + 1.0) / 2.0
+    x = x.reshape(3,height,width).permute(1,2,0).clip(0,1).mul(255).cast(dtypes.uint8)
+    return x
+
+params: Dict = {
+  "unet_config": {
+    "adm_in_ch": None,
+    "in_ch": 4,
+    "out_ch": 4,
+    "model_ch": 320,
+    "attention_resolutions": [4, 2, 1],
+    "num_res_blocks": 2,
+    "channel_mult": [1, 2, 4, 4],
+    "d_head": 64,
+    "transformer_depth": [1, 1, 1, 1],
+    "ctx_dim": 1024,
+    "use_linear": True,
+  },
+  "cond_stage_config": {
+    "dims": 1024,
+    "n_heads": 16,
+    "layers": 24,
+    "return_pooled": False,
+    "ln_penultimate": True,
+  }
+}
+
+if __name__ == "__main__":
+  default_prompt = "a horse sized cat eating a bagel"
+  parser = argparse.ArgumentParser(description='Run Stable Diffusion v2.X', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+  parser.add_argument('--steps',       type=int,   default=10, help="The number of diffusion steps")
+  parser.add_argument('--prompt',      type=str,   default=default_prompt, help="Description of image to generate")
+  parser.add_argument('--out',         type=str,   default=Path(tempfile.gettempdir()) / "rendered.png", help="Output filename")
+  parser.add_argument('--seed',        type=int,   help="Set the random latent seed")
+  parser.add_argument('--guidance',    type=float, default=7.5, help="Prompt strength")
+  parser.add_argument('--width',       type=int,   default=768, help="The output image width")
+  parser.add_argument('--height',      type=int,   default=768, help="The output image height")
+  parser.add_argument('--weights-fn',  type=str,   help="Filename of weights to use")
+  parser.add_argument('--weights-url', type=str,   help="Custom URL to download weights from")
+  parser.add_argument('--timing',      action='store_true', help="Print timing per step")
+  parser.add_argument('--noshow',      action='store_true', help="Don't show the image")
+  parser.add_argument('--fp16',        action='store_true', help="Cast the weights to float16")
+  args = parser.parse_args()
+
+  N = 1
+  C = 4
+  F = 8
+  assert args.width  % F == 0, f"img_width must be multiple of {F}, got {args.width}"
+  assert args.height % F == 0, f"img_height must be multiple of {F}, got {args.height}"
+
+  Tensor.no_grad = True
+  if args.seed is not None:
+    Tensor.manual_seed(args.seed)
+
+  model = StableDiffusionV2(**params)
+
+  default_weights_url = 'https://huggingface.co/stabilityai/stable-diffusion-2-1/resolve/main/v2-1_768-ema-pruned.safetensors'
+  weights_fn = args.weights_fn
+  if not weights_fn:
+    weights_url = args.weights_url if args.weights_url else default_weights_url
+    weights_fn  = fetch(weights_url, os.path.basename(str(weights_url)))
+  load_state_dict(model, safe_load(weights_fn), strict=False)
+
+  if args.fp16:
+    for k,v in get_state_dict(model).items():
+      if k.startswith("model"):
+        v.replace(v.cast(dtypes.float16).realize())
+
+  c  = { "crossattn": model.cond_stage_model(args.prompt) }
+  uc = { "crossattn": model.cond_stage_model("") }
+  del model.cond_stage_model
+  print("created conditioning")
+
+  shape = (N, C, args.height // F, args.width // F)
+  randn = Tensor.randn(shape)
+
+  sampler = DPMPP2MSampler(args.guidance)
+  z = sampler(model.denoise, randn, c, uc, args.steps, timing=args.timing)
+  print("created samples")
+  x = model.decode(z, args.height, args.width).realize()
+  print("decoded samples")
+  print(x.shape)
+
+  im = Image.fromarray(x.numpy())
+  print(f"saving {args.out}")
+  im.save(args.out)
+
+  if not args.noshow:
+    im.show()

+ 428 - 0
tinychat/examples/sdxl.py

@@ -0,0 +1,428 @@
+# This file incorporates code from the following:
+# Github Name                    | License | Link
+# Stability-AI/generative-models | MIT     | https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/LICENSE-CODE
+# mlfoundations/open_clip        | MIT     | https://github.com/mlfoundations/open_clip/blob/58e4e39aaabc6040839b0d2a7e8bf20979e4558a/LICENSE
+
+from tinygrad import Tensor, TinyJit, dtypes
+from tinygrad.nn import Conv2d, GroupNorm
+from tinygrad.nn.state import safe_load, load_state_dict
+from tinygrad.helpers import fetch, trange, colored, Timing, GlobalCounters
+from extra.models.clip import Embedder, FrozenClosedClipEmbedder, FrozenOpenClipEmbedder
+from extra.models.unet import UNetModel, Upsample, Downsample, timestep_embedding
+from examples.stable_diffusion import ResnetBlock, Mid
+import numpy as np
+
+from typing import Dict, List, Callable, Optional, Any, Set, Tuple
+import argparse, tempfile
+from abc import ABC, abstractmethod
+from pathlib import Path
+from PIL import Image
+
+
+# configs:
+# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/configs/inference/sd_xl_base.yaml
+# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/configs/inference/sd_xl_refiner.yaml
+configs: Dict = {
+  "SDXL_Base": {
+    "model": {"adm_in_ch": 2816, "in_ch": 4, "out_ch": 4, "model_ch": 320, "attention_resolutions": [4, 2], "num_res_blocks": 2, "channel_mult": [1, 2, 4], "d_head": 64, "transformer_depth": [1, 2, 10], "ctx_dim": 2048, "use_linear": True},
+    "conditioner": {"concat_embedders": ["original_size_as_tuple", "crop_coords_top_left", "target_size_as_tuple"]},
+    "first_stage_model": {"ch": 128, "in_ch": 3, "out_ch": 3, "z_ch": 4, "ch_mult": [1, 2, 4, 4], "num_res_blocks": 2, "resolution": 256},
+    "denoiser": {"num_idx": 1000},
+  },
+  "SDXL_Refiner": {
+    "model": {"adm_in_ch": 2560, "in_ch": 4, "out_ch": 4, "model_ch": 384, "attention_resolutions": [4, 2], "num_res_blocks": 2, "channel_mult": [1, 2, 4, 4], "d_head": 64, "transformer_depth": [4, 4, 4, 4], "ctx_dim": [1280, 1280, 1280, 1280], "use_linear": True},
+    "conditioner": {"concat_embedders": ["original_size_as_tuple", "crop_coords_top_left", "aesthetic_score"]},
+    "first_stage_model": {"ch": 128, "in_ch": 3, "out_ch": 3, "z_ch": 4, "ch_mult": [1, 2, 4, 4], "num_res_blocks": 2, "resolution": 256},
+    "denoiser": {"num_idx": 1000},
+  }
+}
+
+
+def tensor_identity(x:Tensor) -> Tensor:
+  return x
+
+
+class DiffusionModel:
+  def __init__(self, *args, **kwargs):
+    self.diffusion_model = UNetModel(*args, **kwargs)
+
+
+class Embedder(ABC):
+  input_key: str
+  @abstractmethod
+  def __call__(self, x:Tensor) -> Tensor:
+    pass
+
+
+# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/encoders/modules.py#L913
+class ConcatTimestepEmbedderND(Embedder):
+  def __init__(self, outdim:int, input_key:str):
+    self.outdim = outdim
+    self.input_key = input_key
+
+  def __call__(self, x:Tensor):
+    assert len(x.shape) == 2
+    emb = timestep_embedding(x.flatten(), self.outdim)
+    emb = emb.reshape((x.shape[0],-1))
+    return emb
+
+
+# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/encoders/modules.py#L71
+class Conditioner:
+  OUTPUT_DIM2KEYS = {2: "vector", 3: "crossattn", 4: "concat", 5: "concat"}
+  KEY2CATDIM      = {"vector": 1, "crossattn": 2, "concat": 1}
+  embedders: List[Embedder]
+
+  def __init__(self, concat_embedders:List[str]):
+    self.embedders = [
+      FrozenClosedClipEmbedder(ret_layer_idx=11),
+      FrozenOpenClipEmbedder(dims=1280, n_heads=20, layers=32, return_pooled=True),
+    ]
+    for input_key in concat_embedders:
+      self.embedders.append(ConcatTimestepEmbedderND(256, input_key))
+
+  def get_keys(self) -> Set[str]:
+    return set(e.input_key for e in self.embedders)
+
+  def __call__(self, batch:Dict, force_zero_embeddings:List=[]) -> Dict[str,Tensor]:
+    output: Dict[str,Tensor] = {}
+
+    for embedder in self.embedders:
+      emb_out = embedder(batch[embedder.input_key])
+
+      if isinstance(emb_out, Tensor):
+        emb_out = [emb_out]
+      else:
+        assert isinstance(emb_out, (list, tuple))
+
+      for emb in emb_out:
+        if embedder.input_key in force_zero_embeddings:
+          emb = Tensor.zeros_like(emb)
+
+        out_key = self.OUTPUT_DIM2KEYS[len(emb.shape)]
+        if out_key in output:
+          output[out_key] = Tensor.cat(output[out_key], emb, dim=self.KEY2CATDIM[out_key])
+        else:
+          output[out_key] = emb
+
+    return output
+
+
+class FirstStage:
+  """
+  Namespace for First Stage Model components
+  """
+
+  # https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/diffusionmodules/model.py#L487
+  class Encoder:
+    def __init__(self, ch:int, in_ch:int, out_ch:int, z_ch:int, ch_mult:List[int], num_res_blocks:int, resolution:int):
+      self.conv_in = Conv2d(in_ch, ch, kernel_size=3, stride=1, padding=1)
+      in_ch_mult = (1,) + tuple(ch_mult)
+
+      class BlockEntry:
+        def __init__(self, block:List[ResnetBlock], downsample):
+          self.block = block
+          self.downsample = downsample
+      self.down: List[BlockEntry] = []
+      for i_level in range(len(ch_mult)):
+        block = []
+        block_in  = ch * in_ch_mult[i_level]
+        block_out = ch * ch_mult   [i_level]
+        for _ in range(num_res_blocks):
+          block.append(ResnetBlock(block_in, block_out))
+          block_in = block_out
+
+        downsample = tensor_identity if (i_level == len(ch_mult)-1) else Downsample(block_in)
+        self.down.append(BlockEntry(block, downsample))
+
+      self.mid = Mid(block_in)
+
+      self.norm_out = GroupNorm(32, block_in)
+      self.conv_out = Conv2d(block_in, 2*z_ch, kernel_size=3, stride=1, padding=1)
+
+    def __call__(self, x:Tensor) -> Tensor:
+      h = self.conv_in(x)
+      for down in self.down:
+        for block in down.block:
+          h = block(h)
+        h = down.downsample(h)
+
+      h = h.sequential([self.mid.block_1, self.mid.attn_1, self.mid.block_2])
+      h = h.sequential([self.norm_out,    Tensor.swish,    self.conv_out   ])
+      return h
+
+
+  # https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/diffusionmodules/model.py#L604
+  class Decoder:
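+    # VAE decoder: mirrors the encoder, upsampling z_ch latent channels back to out_ch image channels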
+    def __init__(self, ch:int, in_ch:int, out_ch:int, z_ch:int, ch_mult:List[int], num_res_blocks:int, resolution:int):
+      block_in = ch * ch_mult[-1]
+      curr_res = resolution // 2 ** (len(ch_mult) - 1)
+      self.z_shape = (1, z_ch, curr_res, curr_res)
+
+      self.conv_in = Conv2d(z_ch, block_in, kernel_size=3, stride=1, padding=1)
+
+      self.mid = Mid(block_in)
+
+      class BlockEntry:
+        def __init__(self, block:List[ResnetBlock], upsample:Callable[[Any],Any]):
+          self.block = block
+          self.upsample = upsample
+      self.up: List[BlockEntry] = []
+      for i_level in reversed(range(len(ch_mult))):
+        block = []
+        block_out = ch * ch_mult[i_level]
+        for _ in range(num_res_blocks + 1):
+          block.append(ResnetBlock(block_in, block_out))
+          block_in = block_out
+
+        upsample = tensor_identity if i_level == 0 else Upsample(block_in)
+        self.up.insert(0, BlockEntry(block, upsample)) # type: ignore
+
+      self.norm_out = GroupNorm(32, block_in)
+      self.conv_out = Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
+
+    def __call__(self, z:Tensor) -> Tensor:
+      h = z.sequential([self.conv_in, self.mid.block_1, self.mid.attn_1, self.mid.block_2])
+
+      for up in self.up[::-1]:
+        for block in up.block:
+          h = block(h)
+        h = up.upsample(h)
+
+      h = h.sequential([self.norm_out, Tensor.swish, self.conv_out])
+      return h
+
+
+# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/models/autoencoder.py#L102
+# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/models/autoencoder.py#L437
+# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/models/autoencoder.py#L508
+class FirstStageModel:
+  def __init__(self, embed_dim:int=4, **kwargs):
+    self.encoder = FirstStage.Encoder(**kwargs)
+    self.decoder = FirstStage.Decoder(**kwargs)
+    self.quant_conv = Conv2d(2*kwargs["z_ch"], 2*embed_dim, 1)
+    self.post_quant_conv = Conv2d(embed_dim, kwargs["z_ch"], 1)
+
+  def decode(self, z:Tensor) -> Tensor:
+    return z.sequential([self.post_quant_conv, self.decoder])
+
+
+# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/diffusionmodules/discretizer.py#L42
+class LegacyDDPMDiscretization:
+  def __init__(self, linear_start:float=0.00085, linear_end:float=0.0120, num_timesteps:int=1000):
+    self.num_timesteps = num_timesteps
+    betas = np.linspace(linear_start**0.5, linear_end**0.5, num_timesteps, dtype=np.float32) ** 2
+    alphas = 1.0 - betas
+    self.alphas_cumprod = np.cumprod(alphas, axis=0)
+
+  def __call__(self, n:int, flip:bool=False) -> Tensor:
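+    # select n timesteps from the DDPM schedule and convert alpha_bar to sigma = sqrt((1 - alpha_bar) / alpha_bar)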
+    if n < self.num_timesteps:
+      timesteps = np.linspace(self.num_timesteps - 1, 0, n, endpoint=False).astype(int)[::-1]
+      alphas_cumprod = self.alphas_cumprod[timesteps]
+    elif n == self.num_timesteps:
+      alphas_cumprod = self.alphas_cumprod
+    sigmas = Tensor((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
+    sigmas = Tensor.cat(Tensor.zeros((1,)), sigmas)
+    return sigmas if flip else sigmas.flip(axis=0) # sigmas is "pre-flipped", need to do opposite of flag
+
+
+def append_dims(x:Tensor, t:Tensor) -> Tensor:
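+  # broadcasting helper: right-pad x with singleton dims until it has the same rank as t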
+  dims_to_append = len(t.shape) - len(x.shape)
+  assert dims_to_append >= 0
+  return x.reshape(x.shape + (1,)*dims_to_append)
+
+
+@TinyJit
+def run(model, x, tms, ctx, y, c_out, add):
+  return (model(x, tms, ctx, y)*c_out + add).realize()
+
+
+# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/models/diffusion.py#L19
+class SDXL:
+  def __init__(self, config:Dict):
+    self.conditioner = Conditioner(**config["conditioner"])
+    self.first_stage_model = FirstStageModel(**config["first_stage_model"])
+    self.model = DiffusionModel(**config["model"])
+
+    self.discretization = LegacyDDPMDiscretization()
+    self.sigmas = self.discretization(config["denoiser"]["num_idx"], flip=True)
+
+  # https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/inference/helpers.py#L173
+  def create_conditioning(self, pos_prompt:str, img_width:int, img_height:int, aesthetic_score:float=5.0) -> Tuple[Dict,Dict]:
+    batch_c : Dict = {
+      "txt": pos_prompt,
+      "original_size_as_tuple": Tensor([img_height,img_width]).repeat(N,1),
+      "crop_coords_top_left": Tensor([0,0]).repeat(N,1),
+      "target_size_as_tuple": Tensor([img_height,img_width]).repeat(N,1),
+      "aesthetic_score": Tensor([aesthetic_score]).repeat(N,1),
+    }
+    batch_uc: Dict = {
+      "txt": "",
+      "original_size_as_tuple": Tensor([img_height,img_width]).repeat(N,1),
+      "crop_coords_top_left": Tensor([0,0]).repeat(N,1),
+      "target_size_as_tuple": Tensor([img_height,img_width]).repeat(N,1),
+      "aesthetic_score": Tensor([aesthetic_score]).repeat(N,1),
+    }
+    return self.conditioner(batch_c), self.conditioner(batch_uc, force_zero_embeddings=["txt"])
+
+  # https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/diffusionmodules/denoiser.py#L42
+  def denoise(self, x:Tensor, sigma:Tensor, cond:Dict) -> Tensor:
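+    # discrete eps-denoiser: map sigma to its schedule index, feed the model x*c_in at that index, and return x + c_out*prediction (c_out = -sigma)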
+
+    def sigma_to_idx(s:Tensor) -> Tensor:
+      dists = s - self.sigmas.unsqueeze(1)
+      return dists.abs().argmin(axis=0).view(*s.shape)
+
+    sigma = self.sigmas[sigma_to_idx(sigma)]
+    sigma_shape = sigma.shape
+    sigma = append_dims(sigma, x)
+
+    c_out   = -sigma
+    c_in    = 1 / (sigma**2 + 1.0) ** 0.5
+    c_noise = sigma_to_idx(sigma.reshape(sigma_shape))
+
+    def prep(*tensors:Tensor):
+      return tuple(t.cast(dtypes.float16).realize() for t in tensors)
+
+    return run(self.model.diffusion_model, *prep(x*c_in, c_noise, cond["crossattn"], cond["vector"], c_out, x))
+
+  def decode(self, x:Tensor) -> Tensor:
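+    # undo the SDXL latent scale factor (0.13025) before running the VAE decoder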
+    return self.first_stage_model.decode(1.0 / 0.13025 * x)
+
+
+class VanillaCFG:
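+  # classifier-free guidance: the unconditional and conditional batches are run together, then blended as x_u + scale*(x_c - x_u)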
+  def __init__(self, scale:float):
+    self.scale = scale
+
+  def prepare_inputs(self, x:Tensor, s:float, c:Dict, uc:Dict) -> Tuple[Tensor,Tensor,Tensor]:
+    c_out = {}
+    for k in c:
+      assert k in ["vector", "crossattn", "concat"]
+      c_out[k] = Tensor.cat(uc[k], c[k], dim=0)
+    return Tensor.cat(x, x), Tensor.cat(s, s), c_out
+
+  def __call__(self, x:Tensor, sigma:float) -> Tensor:
+    x_u, x_c = x.chunk(2)
+    x_pred = x_u + self.scale*(x_c - x_u)
+    return x_pred
+
+
+# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/diffusionmodules/sampling.py#L21
+# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/diffusionmodules/sampling.py#L287
+class DPMPP2MSampler:
+  def __init__(self, cfg_scale:float):
+    self.discretization = LegacyDDPMDiscretization()
+    self.guider = VanillaCFG(cfg_scale)
+
+  def sampler_step(self, old_denoised:Optional[Tensor], prev_sigma:Optional[Tensor], sigma:Tensor, next_sigma:Tensor, denoiser, x:Tensor, c:Dict, uc:Dict) -> Tuple[Tensor,Tensor]:
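+    # DPM-Solver++(2M) step in t = -log(sigma) space: h is the step size and r the ratio against the previous step;
+    # the second-order update is only taken once an old denoised estimate exists and the next sigma is nonzero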
+    denoised = denoiser(*self.guider.prepare_inputs(x, sigma, c, uc))
+    denoised = self.guider(denoised, sigma)
+
+    t, t_next = sigma.log().neg(), next_sigma.log().neg()
+    h = t_next - t
+    r = None if prev_sigma is None else (t - prev_sigma.log().neg()) / h
+
+    mults = [t_next.neg().exp()/t.neg().exp(), (-h).exp().sub(1)]
+    if r is not None:
+      mults.extend([1 + 1/(2*r), 1/(2*r)])
+    mults = [append_dims(m, x) for m in mults]
+
+    x_standard = mults[0]*x - mults[1]*denoised
+    if (old_denoised is None) or (next_sigma.sum().numpy().item() < 1e-14):
+      return x_standard, denoised
+
+    denoised_d = mults[2]*denoised - mults[3]*old_denoised
+    x_advanced = mults[0]*x        - mults[1]*denoised_d
+    x = Tensor.where(append_dims(next_sigma, x) > 0.0, x_advanced, x_standard)
+    return x, denoised
+
+  def __call__(self, denoiser, x:Tensor, c:Dict, uc:Dict, num_steps:int, timing=False) -> Tensor:
+    sigmas = self.discretization(num_steps)
+    x *= Tensor.sqrt(1.0 + sigmas[0] ** 2.0)
+    num_sigmas = len(sigmas)
+
+    old_denoised = None
+    for i in trange(num_sigmas - 1):
+      with Timing("step in ", enabled=timing, on_exit=lambda _: f", using {GlobalCounters.mem_used/1e9:.2f} GB"):
+        x, old_denoised = self.sampler_step(
+          old_denoised=old_denoised,
+          prev_sigma=(None if i==0 else sigmas[i-1].reshape(x.shape[0])),
+          sigma=sigmas[i].reshape(x.shape[0]),
+          next_sigma=sigmas[i+1].reshape(x.shape[0]),
+          denoiser=denoiser,
+          x=x,
+          c=c,
+          uc=uc,
+        )
+        x.realize()
+        old_denoised.realize()
+
+    return x
+
+
+if __name__ == "__main__":
+  default_prompt = "a horse sized cat eating a bagel"
+  parser = argparse.ArgumentParser(description="Run SDXL", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+  parser.add_argument('--steps',    type=int,   default=10, help="The number of diffusion steps")
+  parser.add_argument('--prompt',   type=str,   default=default_prompt, help="Description of image to generate")
+  parser.add_argument('--out',      type=str,   default=Path(tempfile.gettempdir()) / "rendered.png", help="Output filename")
+  parser.add_argument('--seed',     type=int,   help="Set the random latent seed")
+  parser.add_argument('--guidance', type=float, default=6.0, help="Prompt strength")
+  parser.add_argument('--width',    type=int,   default=1024, help="The output image width")
+  parser.add_argument('--height',   type=int,   default=1024, help="The output image height")
+  parser.add_argument('--weights',  type=str,   help="Custom path to weights")
+  parser.add_argument('--timing',   action='store_true', help="Print timing per step")
+  parser.add_argument('--noshow',   action='store_true', help="Don't show the image")
+  args = parser.parse_args()
+
+  Tensor.no_grad = True
+  if args.seed is not None:
+    Tensor.manual_seed(args.seed)
+
+  model = SDXL(configs["SDXL_Base"])
+
+  default_weight_url = 'https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/sd_xl_base_1.0.safetensors'
+  weights = args.weights if args.weights else fetch(default_weight_url, 'sd_xl_base_1.0.safetensors')
+  load_state_dict(model, safe_load(weights), strict=False)
+
+  N = 1
+  C = 4
+  F = 8
+
+  assert args.width  % F == 0, f"img_width must be multiple of {F}, got {args.width}"
+  assert args.height % F == 0, f"img_height must be multiple of {F}, got {args.height}"
+
+  c, uc = model.create_conditioning(args.prompt, args.width, args.height)
+  del model.conditioner
+  for v in c .values(): v.realize()
+  for v in uc.values(): v.realize()
+  print("created batch")
+
+  # https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/inference/helpers.py#L101
+  shape = (N, C, args.height // F, args.width // F)
+  randn = Tensor.randn(shape)
+
+  sampler = DPMPP2MSampler(args.guidance)
+  z = sampler(model.denoise, randn, c, uc, args.steps, timing=args.timing)
+  print("created samples")
+  x = model.decode(z).realize()
+  print("decoded samples")
+
+  # make image correct size and scale
+  x = (x + 1.0) / 2.0
+  x = x.reshape(3,args.height,args.width).permute(1,2,0).clip(0,1).mul(255).cast(dtypes.uint8)
+  print(x.shape)
+
+  im = Image.fromarray(x.numpy())
+  print(f"saving {args.out}")
+  im.save(args.out)
+
+  if not args.noshow:
+    im.show()
+
+  # validation!
+  if args.prompt == default_prompt and args.steps == 10 and args.seed == 0 and args.guidance == 6.0 and args.width == args.height == 1024 \
+    and not args.weights:
+    ref_image = Tensor(np.array(Image.open(Path(__file__).parent / "sdxl_seed0.png")))
+    distance = (((x - ref_image).cast(dtypes.float) / ref_image.max())**2).mean().item()
+    assert distance < 2e-3, colored(f"validation failed with {distance=}", "red")
+    print(colored(f"output validated with {distance=}", "green"))

BIN=BIN
tinychat/examples/sdxl_seed0.png


+ 136 - 0
tinychat/examples/serious_mnist.py

@@ -0,0 +1,136 @@
+#!/usr/bin/env python
+#inspired by https://github.com/Matuzas77/MNIST-0.17/blob/master/MNIST_final_solution.ipynb
+import sys
+import numpy as np
+from tinygrad.nn.state import get_parameters
+from tinygrad.tensor import Tensor
+from tinygrad.nn import BatchNorm2d, optim
+from tinygrad.helpers import getenv
+from extra.datasets import fetch_mnist
+from extra.augment import augment_img
+from extra.training import train, evaluate
+GPU = getenv("GPU")
+QUICK = getenv("QUICK")
+DEBUG = getenv("DEBUG")
+
+class SqueezeExciteBlock2D:
+  def __init__(self, filters):
+    self.filters = filters
+    self.weight1 = Tensor.scaled_uniform(self.filters, self.filters//32)
+    self.bias1 = Tensor.scaled_uniform(1,self.filters//32)
+    self.weight2 = Tensor.scaled_uniform(self.filters//32, self.filters)
+    self.bias2 = Tensor.scaled_uniform(1, self.filters)
+
+  def __call__(self, input):
+    se = input.avg_pool2d(kernel_size=(input.shape[2], input.shape[3])) #GlobalAveragePool2D
+    se = se.reshape(shape=(-1, self.filters))
+    se = se.dot(self.weight1) + self.bias1
+    se = se.relu()
+    se = se.dot(self.weight2) + self.bias2
+    se = se.sigmoid().reshape(shape=(-1,self.filters,1,1)) #for broadcasting
+    se = input.mul(se)
+    return se
+
+class ConvBlock:
+  def __init__(self, h, w, inp, filters=128, conv=3):
+    self.h, self.w = h, w
+    self.inp = inp
+    #init weights
+    self.cweights = [Tensor.scaled_uniform(filters, inp if i==0 else filters, conv, conv) for i in range(3)]
+    self.cbiases = [Tensor.scaled_uniform(1, filters, 1, 1) for i in range(3)]
+    #init layers
+    self._bn = BatchNorm2d(filters)
+    self._seb = SqueezeExciteBlock2D(filters)
+
+  def __call__(self, input):
+    x = input.reshape(shape=(-1, self.inp, self.w, self.h))
+    for cweight, cbias in zip(self.cweights, self.cbiases):
+      x = x.pad2d(padding=[1,1,1,1]).conv2d(cweight).add(cbias).relu()
+    x = self._bn(x)
+    x = self._seb(x)
+    return x
+
+class BigConvNet:
+  def __init__(self):
+    self.conv = [ConvBlock(28,28,1), ConvBlock(28,28,128), ConvBlock(14,14,128)]
+    self.weight1 = Tensor.scaled_uniform(128,10)
+    self.weight2 = Tensor.scaled_uniform(128,10)
+
+  def parameters(self):
+    if DEBUG: #keeping this for a moment
+      pars = [par for par in get_parameters(self) if par.requires_grad]
+      no_pars = 0
+      for par in pars:
+        print(par.shape)
+        no_pars += np.prod(par.shape)
+      print('no of parameters', no_pars)
+      return pars
+    else:
+      return get_parameters(self)
+
+  def save(self, filename):
+    with open(filename+'.npy', 'wb') as f:
+      for par in get_parameters(self):
+        #if par.requires_grad:
+        np.save(f, par.numpy())
+
+  def load(self, filename):
+    with open(filename+'.npy', 'rb') as f:
+      for par in get_parameters(self):
+        #if par.requires_grad:
+        try:
+          par.numpy()[:] = np.load(f)
+          if GPU:
+            par.gpu()
+        except:
+          print('Could not load parameter')
+
+  def forward(self, x):
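+    # two global pooling heads (average and max) feed separate linear layers whose logits are summed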
+    x = self.conv[0](x)
+    x = self.conv[1](x)
+    x = x.avg_pool2d(kernel_size=(2,2))
+    x = self.conv[2](x)
+    x1 = x.avg_pool2d(kernel_size=(14,14)).reshape(shape=(-1,128)) #global
+    x2 = x.max_pool2d(kernel_size=(14,14)).reshape(shape=(-1,128)) #global
+    xo = x1.dot(self.weight1) + x2.dot(self.weight2)
+    return xo
+
+
+if __name__ == "__main__":
+  lrs = [1e-4, 1e-5] if QUICK else [1e-3, 1e-4, 1e-5, 1e-5]
+  epochss = [2, 1] if QUICK else [13, 3, 3, 1]
+  BS = 32
+
+  lmbd = 0.00025
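+  # loss: sparse categorical cross-entropy plus an L1 penalty (weight lmbd) on the two classifier weight matrices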
+  lossfn = lambda out,y: out.sparse_categorical_crossentropy(y) + lmbd*(model.weight1.abs() + model.weight2.abs()).sum()
+  X_train, Y_train, X_test, Y_test = fetch_mnist()
+  X_train = X_train.reshape(-1, 28, 28).astype(np.uint8)
+  X_test = X_test.reshape(-1, 28, 28).astype(np.uint8)
+  steps = len(X_train)//BS
+  np.random.seed(1337)
+  if QUICK:
+    steps = 1
+    X_test, Y_test = X_test[:BS], Y_test[:BS]
+
+  model = BigConvNet()
+
+  if len(sys.argv) > 1:
+    try:
+      model.load(sys.argv[1])
+      print('Loaded weights "'+sys.argv[1]+'", evaluating...')
+      evaluate(model, X_test, Y_test, BS=BS)
+    except:
+      print('could not load weights "'+sys.argv[1]+'".')
+
+  if GPU:
+    params = get_parameters(model)
+    [x.gpu_() for x in params]
+
+  for lr, epochs in zip(lrs, epochss):
+    optimizer = optim.Adam(model.parameters(), lr=lr)
+    for epoch in range(1,epochs+1):
+      #first epoch without augmentation
+      X_aug = X_train if epoch == 1 else augment_img(X_train)
+      train(model, X_aug, Y_train, optimizer, steps=steps, lossfn=lossfn, BS=BS)
+      accuracy = evaluate(model, X_test, Y_test, BS=BS)
+      model.save(f'examples/checkpoint{accuracy * 1e6:.0f}')

+ 17 - 0
tinychat/examples/simple_conv_bn.py

@@ -0,0 +1,17 @@
+from tinygrad.tensor import Tensor
+from tinygrad.nn import Conv2d, BatchNorm2d
+from tinygrad.nn.state import get_parameters
+
+if __name__ == "__main__":
+  with Tensor.train():
+
+    BS, C1, H, W = 4, 16, 224, 224
+    C2, K, S, P = 64, 7, 2, 1
+
+    x = Tensor.uniform(BS, C1, H, W)
+    conv = Conv2d(C1, C2, kernel_size=K, stride=S, padding=P)
+    bn = BatchNorm2d(C2, track_running_stats=False)
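+    # realize the input and all layer parameters up front so the network run below starts from materialized buffers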
+    for t in get_parameters([x, conv, bn]): t.realize()
+
+    print("running network")
+    x.sequential([conv, bn]).numpy()

+ 673 - 0
tinychat/examples/so_vits_svc.py

@@ -0,0 +1,673 @@
+# original implementation: https://github.com/svc-develop-team/so-vits-svc
+from __future__ import annotations
+import sys, logging, time, io, math, argparse, operator, numpy as np
+from functools import partial, reduce
+from pathlib import Path
+from typing import Tuple, Optional, Type
+from tinygrad import nn, dtypes, Tensor
+from tinygrad.helpers import getenv, fetch
+from tinygrad.nn.state import torch_load
+from examples.vits import ResidualCouplingBlock, PosteriorEncoder, Encoder, ResBlock1, ResBlock2, LRELU_SLOPE, sequence_mask, split, get_hparams_from_file, load_checkpoint, weight_norm, HParams
+from examples.sovits_helpers import preprocess
+import soundfile
+
+DEBUG = getenv("DEBUG")
+
+F0_BIN = 256
+F0_MAX = 1100.0
+F0_MIN = 50.0
+F0_MEL_MIN = 1127 * np.log(1 + F0_MIN / 700)
+F0_MEL_MAX = 1127 * np.log(1 + F0_MAX / 700)
+
+def download_if_not_present(file_path: Path, url: str):
+  if not Path(file_path).is_file(): fetch(url, file_path)  # fetch() saves to file_path when given an explicit path
+  return file_path
+
+class SpeechEncoder:
+  def __init__(self, hidden_dim, model:ContentVec): self.hidden_dim, self.model = hidden_dim, model
+  def encode(self, wav: Tensor) -> Tensor: raise NotImplementedError("implement me")
+  @classmethod
+  def load_from_pretrained(cls, checkpoint_path:str, checkpoint_url:str) -> SpeechEncoder:
+    contentvec = ContentVec.load_from_pretrained(checkpoint_path, checkpoint_url)
+    return cls(contentvec)
+
+class ContentVec256L9(SpeechEncoder):
+  def __init__(self, model:ContentVec): super().__init__(hidden_dim=256, model=model)
+  def encode(self, wav: Tensor):
+    feats = wav
+    if len(feats.shape) == 2:  # double channels
+      feats = feats.mean(-1)
+    assert len(feats.shape) == 1, feats.dim()
+    feats = feats.reshape(1, -1)
+    padding_mask = Tensor.zeros_like(feats).cast(dtypes.bool)
+    logits = self.model.extract_features(feats.to(wav.device), padding_mask=padding_mask.to(wav.device), output_layer=9)
+    feats = self.model.final_proj(logits[0])
+    return feats.transpose(1,2)
+
+class ContentVec768L12(SpeechEncoder):
+  def __init__(self, model:ContentVec): super().__init__(hidden_dim=768, model=model)
+  def encode(self, wav: Tensor):
+    feats = wav
+    if len(feats.shape) == 2:  # double channels
+      feats = feats.mean(-1)
+    assert len(feats.shape) == 1, feats.dim()
+    feats = feats.reshape(1, -1)
+    padding_mask = Tensor.zeros_like(feats).cast(dtypes.bool)
+    logits = self.model.extract_features(feats.to(wav.device), padding_mask=padding_mask.to(wav.device), output_layer=12)
+    return logits[0].transpose(1,2)
+
+# original code for contentvec: https://github.com/auspicious3000/contentvec/
+class ContentVec:
+  # self.final_proj dims are hardcoded and depend on fairseq.data.dictionary Dictionary in the checkpoint. This param can't yet be loaded since there is no pickle for it. See with DEBUG=2.
+  # This means that ContentVec only works with the HuBERT weights used in all SVC models
+  def __init__(self, cfg: HParams):
+    self.feature_grad_mult, self.untie_final_proj = cfg.feature_grad_mult, cfg.untie_final_proj
+    feature_enc_layers = eval(cfg.conv_feature_layers)
+    self.embed = feature_enc_layers[-1][0]
+    final_dim = cfg.final_dim if cfg.final_dim > 0 else cfg.encoder_embed_dim
+    self.feature_extractor = ConvFeatureExtractionModel(conv_layers=feature_enc_layers, dropout=0.0, mode=cfg.extractor_mode, conv_bias=cfg.conv_bias)
+    self.post_extract_proj = nn.Linear(self.embed, cfg.encoder_embed_dim) if self.embed != cfg.encoder_embed_dim else None
+    self.encoder = TransformerEncoder(cfg)
+    self.layer_norm = nn.LayerNorm(self.embed)
+    self.final_proj = nn.Linear(cfg.encoder_embed_dim, final_dim * 1) if self.untie_final_proj else nn.Linear(cfg.encoder_embed_dim, final_dim)
+    self.mask_emb = Tensor.uniform(cfg.encoder_embed_dim, dtype=dtypes.float32)
+    self.label_embs_concat = Tensor.uniform(504, final_dim, dtype=dtypes.float32)
+  def forward_features(self, source, padding_mask):
+    if self.feature_grad_mult > 0:
+      features = self.feature_extractor(source, padding_mask)
+      if self.feature_grad_mult != 1.0: pass  # training: GradMultiply.forward(features, self.feature_grad_mult)
+    else:
+      features = self.feature_extractor(source, padding_mask)
+    return features
+  def forward_padding_mask(self, features, padding_mask):  # replaces original forward_padding_mask for batch inference
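+    # recompute the mask at feature rate; 400/320 are assumed to match the conv feature extractor's receptive field and hop in samples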
+    lengths_org = tilde(padding_mask.cast(dtypes.bool)).cast(dtypes.int64).sum(1)  # ensure it's bool for tilde
+    lengths = (lengths_org - 400).float().div(320).floor().cast(dtypes.int64) + 1  # intermediate float to divide
+    padding_mask = lengths_to_padding_mask(lengths)
+    return padding_mask
+  def extract_features(self, source: Tensor, spk_emb:Tensor=None, padding_mask=None, ret_conv=False, output_layer=None, tap=False):
+    features = self.forward_features(source, padding_mask)
+    if padding_mask is not None:
+      padding_mask = self.forward_padding_mask(features, padding_mask)
+    features = features.transpose(1, 2)
+    features = self.layer_norm(features)
+    if self.post_extract_proj is not None:
+      features = self.post_extract_proj(features)
+    x, _ = self.encoder(features, spk_emb, padding_mask=padding_mask, layer=(None if output_layer is None else output_layer - 1), tap=tap)
+    res = features if ret_conv else x
+    return res, padding_mask
+  @classmethod
+  def load_from_pretrained(cls, checkpoint_path:str, checkpoint_url:str) -> ContentVec:
+    download_if_not_present(checkpoint_path, checkpoint_url)
+    cfg = load_fairseq_cfg(checkpoint_path)
+    enc = cls(cfg.model)
+    _ = load_checkpoint_enc(checkpoint_path, enc, None)
+    logging.debug(f"{cls.__name__}: Loaded model with cfg={cfg}")
+    return enc
+
+class TransformerEncoder:
+  def __init__(self, cfg: HParams):
+    def make_conv() -> nn.Conv1d:
+      layer = nn.Conv1d(self.embedding_dim, self.embedding_dim, kernel_size=cfg.conv_pos, padding=cfg.conv_pos // 2, groups=cfg.conv_pos_groups)
+      std = math.sqrt(4 / (cfg.conv_pos * self.embedding_dim))
+      layer.weight, layer.bias = (Tensor.normal(*layer.weight.shape, std=std)), (Tensor.zeros(*layer.bias.shape))
+      # for training: layer.weights need to be weight_normed
+      return layer
+    self.dropout, self.embedding_dim, self.layer_norm_first, self.layerdrop, self.num_layers, self.num_layers_1 = cfg.dropout, cfg.encoder_embed_dim, cfg.layer_norm_first, cfg.encoder_layerdrop, cfg.encoder_layers, cfg.encoder_layers_1
+    self.pos_conv, self.pos_conv_remove = [make_conv()], (1 if cfg.conv_pos % 2 == 0 else 0)
+    self.layers = [
+      TransformerEncoderLayer(self.embedding_dim, cfg.encoder_ffn_embed_dim, cfg.encoder_attention_heads, self.dropout, cfg.attention_dropout, cfg.activation_dropout, cfg.activation_fn, self.layer_norm_first, cond_layer_norm=(i >= cfg.encoder_layers))
+      for i in range(cfg.encoder_layers + cfg.encoder_layers_1)
+      ]
+    self.layer_norm = nn.LayerNorm(self.embedding_dim)
+    self.cond_layer_norm = CondLayerNorm(self.embedding_dim) if cfg.encoder_layers_1 > 0 else None
+    # training: apply init_bert_params
+  def __call__(self, x, spk_emb, padding_mask=None, layer=None, tap=False):
+    x, layer_results = self.extract_features(x, spk_emb, padding_mask, layer, tap)
+    if self.layer_norm_first and layer is None:
+      x = self.cond_layer_norm(x, spk_emb) if (self.num_layers_1 > 0) else self.layer_norm(x)
+    return x, layer_results
+  def extract_features(self, x: Tensor, spk_emb: Tensor, padding_mask=None, tgt_layer=None, tap=False):
+    if tgt_layer is not None:  # and not self.training
+      assert tgt_layer >= 0 and tgt_layer < len(self.layers)
+    if padding_mask is not None:
+      # x[padding_mask] = 0
+      assert padding_mask.shape == x.shape[:len(padding_mask.shape)]  # first few dims of x must match padding_mask
+      tmp_mask = padding_mask.unsqueeze(-1).repeat((1, 1, x.shape[-1]))
+      tmp_mask = tilde(tmp_mask.cast(dtypes.bool))
+      x = tmp_mask.where(x, 0)
+    x_conv = self.pos_conv[0](x.transpose(1,2))
+    if self.pos_conv_remove > 0: x_conv = x_conv[:, :, : -self.pos_conv_remove]
+    x_conv = x_conv.gelu().transpose(1, 2)
+    x = (x + x_conv).transpose(0, 1)  # B x T x C -> T x B x C
+    if not self.layer_norm_first: x = self.layer_norm(x)
+    x = x.dropout(p=self.dropout)
+    layer_results = []
+    r = None
+    for i, layer in enumerate(self.layers):
+      if i < self.num_layers:  # if (not self.training or (dropout_probability > self.layerdrop)) and (i < self.num_layers):
+        assert layer.cond_layer_norm == False
+        x = layer(x, self_attn_padding_mask=padding_mask, need_weights=False)
+        if tgt_layer is not None or tap:
+          layer_results.append(x.transpose(0, 1))
+      if i>= self.num_layers:
+        assert layer.cond_layer_norm == True
+        x = layer(x, emb=spk_emb, self_attn_padding_mask=padding_mask, need_weights=False)
+      if i == tgt_layer:
+        r = x
+        break
+    if r is not None:
+      x = r
+    x = x.transpose(0, 1)  # T x B x C -> B x T x C
+    return x, layer_results
+
+class TransformerEncoderLayer:
+  def __init__(self, embedding_dim=768.0, ffn_embedding_dim=3072.0, num_attention_heads=8.0, dropout=0.1, attention_dropout=0.1, activation_dropout=0.1, activation_fn="relu", layer_norm_first=False, cond_layer_norm=False):
+    def get_activation_fn(activation):
+      if activation == "relu": return Tensor.relu
+      if activation == "gelu": return Tensor.gelu
+      else: raise RuntimeError(f"activation function={activation} is not foreseen")
+    self.embedding_dim, self.dropout, self.activation_dropout, self.layer_norm_first, self.num_attention_heads, self.cond_layer_norm, self.activation_fn = embedding_dim, dropout, activation_dropout, layer_norm_first, num_attention_heads, cond_layer_norm, get_activation_fn(activation_fn)
+    self.self_attn = MultiHeadAttention(self.embedding_dim, self.num_attention_heads)
+    self.self_attn_layer_norm = nn.LayerNorm(self.embedding_dim) if not cond_layer_norm else CondLayerNorm(self.embedding_dim)
+    self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
+    self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)
+    self.final_layer_norm = nn.LayerNorm(self.embedding_dim) if not cond_layer_norm else CondLayerNorm(self.embedding_dim)
+  def __call__(self, x:Tensor, self_attn_mask:Tensor=None, self_attn_padding_mask:Tensor=None, emb:Tensor=None, need_weights=False):
+    #self_attn_padding_mask = self_attn_padding_mask.reshape(x.shape[0], 1, 1, self_attn_padding_mask.shape[1]).expand(-1, self.num_attention_heads, -1, -1).reshape(x.shape[0] * self.num_attention_heads, 1, self_attn_padding_mask.shape[1]) if self_attn_padding_mask is not None else None
+    assert self_attn_mask is None and self_attn_padding_mask is not None
+    residual = x
+    if self.layer_norm_first:
+      x = self.self_attn_layer_norm(x) if not self.cond_layer_norm else self.self_attn_layer_norm(x, emb)
+      x = self.self_attn(x=x, mask=self_attn_padding_mask)
+      x = x.dropout(self.dropout)
+      x = residual + x
+      x = self.final_layer_norm(x) if not self.cond_layer_norm else self.final_layer_norm(x, emb)
+      x = self.activation_fn(self.fc1(x))
+      x = x.dropout(self.activation_dropout)
+      x = self.fc2(x)
+      x = x.dropout(self.dropout)
+      x = residual + x
+    else:
+      x = self.self_attn(x=x, mask=self_attn_padding_mask)
+      x = x.dropout(self.dropout)
+      x = residual + x
+      x = self.self_attn_layer_norm(x) if not self.cond_layer_norm else self.self_attn_layer_norm(x, emb)
+      residual = x
+      x = self.activation_fn(self.fc1(x))
+      x = x.dropout(self.activation_dropout)
+      x = self.fc2(x)
+      x = x.dropout(self.dropout)
+      x = residual + x
+      x = self.final_layer_norm(x) if not self.cond_layer_norm else self.final_layer_norm(x, emb)
+    return x
+
+class MultiHeadAttention:
+  def __init__(self, n_state, n_head):
+    self.n_state, self.n_head = n_state, n_head
+    self.q_proj, self.k_proj, self.v_proj, self.out_proj = [nn.Linear(n_state, n_state) for _ in range(4)]
+  def __call__(self, x:Tensor, xa:Optional[Tensor]=None, mask:Optional[Tensor]=None):
+    x = x.transpose(0,1)  # TxBxC -> BxTxC
+    q, k, v = self.q_proj(x), self.k_proj(xa or x), self.v_proj(xa or x)
+    q, k, v = [x.reshape(*q.shape[:2], self.n_head, -1) for x in (q, k, v)]
+    wv = Tensor.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), None).transpose(1, 2).reshape(*x.shape[:2], -1)
+    ret =  self.out_proj(wv).transpose(0,1)  # BxTxC -> TxBxC
+    return ret
+
+class ConvFeatureExtractionModel:
+  def __init__(self, conv_layers, dropout=.0, mode="default", conv_bias=False):
+    assert mode in {"default", "group_norm_masked", "layer_norm"}
+    def block(n_in, n_out, k, stride, is_layer_norm=False, is_group_norm=False, conv_bias=False):
+      def make_conv():
+        conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias)
+        conv.weight = Tensor.kaiming_normal(*conv.weight.shape)
+        return conv
+      assert (is_layer_norm and is_group_norm) == False, "layer norm and group norm are exclusive"
+      if is_layer_norm:
+        return [make_conv(), partial(Tensor.dropout, p=dropout),[partial(Tensor.transpose, dim0=-2, dim1=-1), nn.LayerNorm(dim, elementwise_affine=True), partial(Tensor.transpose, dim0=-2, dim1=-1)], Tensor.gelu]
+      elif is_group_norm and mode == "default":
+        return [make_conv(), partial(Tensor.dropout, p=dropout), nn.GroupNorm(dim, dim, affine=True), Tensor.gelu]
+      elif is_group_norm and mode == "group_norm_masked":
+        return [make_conv(), partial(Tensor.dropout, p=dropout), GroupNormMasked(dim, dim, affine=True), Tensor.gelu]
+      else:
+        return [make_conv(), partial(Tensor.dropout, p=dropout), Tensor.gelu]
+    in_d, self.conv_layers, self.mode = 1, [], mode
+    for i, cl in enumerate(conv_layers):
+      assert len(cl) == 3, "invalid conv definition: " + str(cl)
+      (dim, k, stride) = cl
+      if i == 0: self.cl = cl
+      self.conv_layers.append(block(in_d, dim, k, stride, is_layer_norm=(mode == "layer_norm"), is_group_norm=((mode == "default" or mode == "group_norm_masked") and i == 0), conv_bias=conv_bias))
+      in_d = dim
+  def __call__(self, x:Tensor, padding_mask:Tensor):
+    x = x.unsqueeze(1)  # BxT -> BxCxT
+    if self.mode == "group_norm_masked":
+      if padding_mask is not None:
+        _, k, stride = self.cl
+        lengths_org = tilde(padding_mask.cast(dtypes.bool)).cast(dtypes.int64).sum(1)  # ensure padding_mask is bool for tilde
+        lengths = (((lengths_org - k) / stride) + 1).floor().cast(dtypes.int64)
+        padding_mask = tilde(lengths_to_padding_mask(lengths)).cast(dtypes.int64)  # lengths_to_padding_mask returns bool tensor
+      x = self.conv_layers[0][0](x)  # padding_mask is numeric
+      x = self.conv_layers[0][1](x)
+      x = self.conv_layers[0][2](x, padding_mask)
+      x = self.conv_layers[0][3](x)
+    else:
+      x = x.sequential(self.conv_layers[0])  # default
+    for _, conv in enumerate(self.conv_layers[1:], start=1):
+      conv = reduce(lambda a,b: operator.iconcat(a,b if isinstance(b, list) else [b]), conv, [])  # flatten
+      x = x.sequential(conv)
+    return x
+
+class CondLayerNorm:  # https://github.com/auspicious3000/contentvec/blob/main/contentvec/modules/cond_layer_norm.py#L10
+  def __init__(self, dim_last, eps=1e-5, dim_spk=256, elementwise_affine=True):
+    self.dim_last, self.eps, self.dim_spk, self.elementwise_affine = dim_last, eps, dim_spk, elementwise_affine
+    if self.elementwise_affine:
+      self.weight_ln = nn.Linear(self.dim_spk, self.dim_last, bias=False)
+      self.bias_ln = nn.Linear(self.dim_spk, self.dim_last, bias=False)
+      self.weight_ln.weight, self.bias_ln.weight = (Tensor.ones(*self.weight_ln.weight.shape)), (Tensor.zeros(*self.bias_ln.weight.shape))
+  def __call__(self, x: Tensor, spk_emb: Tensor):
+    axis = tuple(-1-i for i in range(len(x.shape[1:])))
+    x = x.layernorm(axis=axis, eps=self.eps)
+    if not self.elementwise_affine: return x
+    weights, bias = self.weight_ln(spk_emb), self.bias_ln(spk_emb)
+    return weights * x + bias
+
+class GroupNormMasked:  # https://github.com/auspicious3000/contentvec/blob/d746688a32940f4bee410ed7c87ec9cf8ff04f74/contentvec/modules/fp32_group_norm.py#L16
+  def __init__(self, num_groups, num_channels, eps=1e-5, affine=True):
+    self.num_groups, self.num_channels, self.eps, self.affine = num_groups, num_channels, eps, affine
+    self.weight, self.bias = (Tensor.ones(num_channels), Tensor.zeros(num_channels)) if self.affine else (None, None)
+  def __call__(self, x:Tensor, mask:Tensor):
+    bsz, n_c, length = x.shape
+    assert n_c % self.num_groups == 0
+    x = x.reshape(bsz, self.num_groups, n_c // self.num_groups, length)
+    if mask is None: mask = Tensor.ones_like(x)
+    else: mask = mask.reshape(bsz, 1, 1, length)
+    x = x * mask
+    lengths = mask.sum(axis=3, keepdim=True)
+    assert x.shape[2] == 1
+    mean_ = x.mean(axis=3, keepdim=True)
+    mean = mean_ * length / lengths
+    var = (((x.std(axis=3, keepdim=True) ** 2) + mean_**2) * length / lengths - mean**2) + self.eps
+    return x.add(-mean).div(var.sqrt()).reshape(bsz, n_c, length).mul(self.weight.reshape(1,-1,1)).add(self.bias.reshape(1,-1,1))
+
+class Synthesizer:
+  def __init__(self, spec_channels, segment_size, inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels, ssl_dim, n_speakers, sampling_rate=44100, vol_embedding=False, n_flow_layer=4, **kwargs):
+    self.spec_channels, self.inter_channels, self.hidden_channels, self.filter_channels, self.n_heads, self.n_layers, self.kernel_size, self.p_dropout, self.resblock, self.resblock_kernel_sizes, self.resblock_dilation_sizes, self.upsample_rates, self.upsample_initial_channel, self.upsample_kernel_sizes, self.segment_size, self.n_speakers, self.gin_channels, self.vol_embedding = spec_channels, inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, segment_size, n_speakers, gin_channels, vol_embedding
+    self.emb_g = nn.Embedding(n_speakers, gin_channels)
+    if vol_embedding: self.emb_vol = nn.Linear(1, hidden_channels)
+    self.pre = nn.Conv1d(ssl_dim, hidden_channels, kernel_size=5, padding=2)
+    self.enc_p = TextEncoder(inter_channels, hidden_channels, kernel_size, n_layers, filter_channels=filter_channels, n_heads=n_heads, p_dropout=p_dropout)
+    self.dec = Generator(sampling_rate, inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels)
+    self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels)
+    self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, n_flow_layer, gin_channels=gin_channels)
+    self.emb_uv = nn.Embedding(vocab_size=2, embed_size=hidden_channels)
+  def infer(self, c:Tensor, f0:Tensor, uv:Tensor, g:Tensor=None, noise_scale=0.35, seed=52468, vol=None) -> Tuple[Tensor, Tensor]:
+    Tensor.manual_seed(getenv('SEED', seed))
+    c_lengths = (Tensor.ones([c.shape[0]]) * c.shape[-1]).to(c.device)
+    if len(g.shape) == 1: g = g.unsqueeze(0)
+    g = self.emb_g(g).transpose(1, 2)
+    x_mask = sequence_mask(c_lengths, c.shape[2]).unsqueeze(1).cast(c.dtype)
+    vol = self.emb_vol(vol[:,:,None]).transpose(1,2) if vol is not None and self.vol_embedding else 0
+    x = self.pre(c) * x_mask + self.emb_uv(uv.cast(dtypes.int64)).transpose(1, 2) + vol
+    z_p, _, _, c_mask = self.enc_p.forward(x, x_mask, f0=self._f0_to_coarse(f0), noise_scale=noise_scale)
+    z = self.flow.forward(z_p, c_mask, g=g, reverse=True)
+    o = self.dec.forward(z * c_mask, g=g, f0=f0)
+    return o,f0
+  def _f0_to_coarse(self, f0 : Tensor):
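+    # quantize f0 (Hz) onto the mel scale into integer bins clamped to [1, F0_BIN - 1]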
+    f0_mel = 1127 * (1 + f0 / 700).log()
+    a = (F0_BIN - 2) / (F0_MEL_MAX - F0_MEL_MIN)
+    b = F0_MEL_MIN * a - 1.
+    f0_mel = (f0_mel > 0).where(f0_mel * a - b, f0_mel)
+    f0_coarse = f0_mel.ceil().cast(dtype=dtypes.int64)
+    f0_coarse = f0_coarse * (f0_coarse > 0)
+    f0_coarse = f0_coarse + ((f0_coarse < 1) * 1)
+    f0_coarse = f0_coarse * (f0_coarse < F0_BIN)
+    f0_coarse = f0_coarse + ((f0_coarse >= F0_BIN) * (F0_BIN - 1))
+    return f0_coarse
+  @classmethod
+  def load_from_pretrained(cls, config_path:str, config_url:str, weights_path:str, weights_url:str) -> Synthesizer:
+    download_if_not_present(config_path, config_url)
+    hps = get_hparams_from_file(config_path)
+    download_if_not_present(weights_path, weights_url)
+    net_g = cls(hps.data.filter_length // 2 + 1, hps.train.segment_size // hps.data.hop_length, **hps.model)
+    _ = load_checkpoint(weights_path, net_g, None, skip_list=["f0_decoder"])
+    logging.debug(f"{cls.__name__}:Loaded model with hps: {hps}")
+    return net_g, hps
+
+class TextEncoder:
+  def __init__(self, out_channels, hidden_channels, kernel_size, n_layers, gin_channels=0, filter_channels=None, n_heads=None, p_dropout=None):
+    self.out_channels, self.hidden_channels, self.kernel_size, self.n_layers, self.gin_channels = out_channels, hidden_channels, kernel_size, n_layers, gin_channels
+    self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
+    self.f0_emb = nn.Embedding(256, hidden_channels)  # n_vocab = 256
+    self.enc_ = Encoder(hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout)
+  def forward(self, x, x_mask, f0=None, noise_scale=1):
+    x = x + self.f0_emb(f0).transpose(1, 2)
+    x = self.enc_.forward(x * x_mask, x_mask)
+    stats = self.proj(x) * x_mask
+    m, logs = split(stats, self.out_channels, dim=1)
+    z = (m + randn_like(m) * logs.exp() * noise_scale) * x_mask
+    return z, m, logs, x_mask
+
+class Upsample:
+  def __init__(self, scale_factor):
+    assert scale_factor % 1 == 0, "Only integer scale factor allowed."
+    self.scale = int(scale_factor)
+  def forward(self, x:Tensor):
+    repeats = tuple([1] * len(x.shape) + [self.scale])
+    new_shape = (*x.shape[:-1], x.shape[-1] * self.scale)
+    return x.unsqueeze(-1).repeat(repeats).reshape(new_shape)
+
+class SineGen:
+  def __init__(self, samp_rate, harmonic_num=0, sine_amp=0.1, noise_std=0.003, voice_threshold=0, flag_for_pulse=False):
+    self.sine_amp, self.noise_std, self.harmonic_num, self.sampling_rate, self.voiced_threshold, self.flag_for_pulse = sine_amp, noise_std, harmonic_num, samp_rate, voice_threshold, flag_for_pulse
+    self.dim = self.harmonic_num + 1
+  def _f02uv(self, f0): return (f0 > self.voiced_threshold).float()  #generate uv signal
+  def _f02sine(self, f0_values):
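+    # integrate per-sample frequency (rad_values, in revolutions) into phase with a random initial phase, correcting the cumulative sum at wrap-around points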
+    def padDiff(x : Tensor): return (x.pad2d((0,0,-1,1)) - x).pad2d((0,0,0,-1))
+    def mod(x: Tensor, n: int) -> Tensor: return x - n * x.div(n).floor()  # this is what the % operator does in pytorch.
+    rad_values = mod((f0_values / self.sampling_rate) , 1)  # convert to F0 in rad
+    rand_ini = Tensor.rand(f0_values.shape[0], f0_values.shape[2], device=f0_values.device)  # initial phase noise
+
+    #rand_ini[:, 0] = 0
+    m = Tensor.ones(f0_values.shape[0]).unsqueeze(1).pad2d((0,f0_values.shape[2]-1,0,0)).cast(dtypes.bool)
+    m = tilde(m)
+    rand_ini = m.where(rand_ini, 0)
+
+    #rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
+    tmp = rad_values[:, 0, :] + rand_ini
+    m = Tensor.ones(tmp.shape).pad2d((0,0,0,rad_values.shape[1]-1,0)).cast(dtypes.bool)
+    m = tilde(m)
+    tmp = tmp.unsqueeze(1).pad2d((0,0,0,rad_values.shape[1]-1,0))
+    rad_values = m.where(rad_values, tmp)
+
+    tmp_over_one = mod(rad_values.cumsum(1), 1)
+    tmp_over_one_idx = padDiff(tmp_over_one) < 0
+    cumsum_shift = Tensor.zeros_like(rad_values)
+
+    #cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
+    tmp_over_one_idx = (tmp_over_one_idx * -1.0).pad2d((0,0,1,0))
+    cumsum_shift = tmp_over_one_idx
+
+    sines = ((rad_values + cumsum_shift).cumsum(1) * 2 * np.pi).sin()
+    return sines
+  def forward(self, f0, upp=None):
+    fn = f0.mul(Tensor([[range(1, self.harmonic_num + 2)]], dtype=dtypes.float32).to(f0.device))
+    sine_waves = self._f02sine(fn) * self.sine_amp  #generate sine waveforms
+    uv = self._f02uv(f0)  # generate uv signal
+    noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
+    noise = noise_amp * randn_like(sine_waves)
+    sine_waves = sine_waves * uv + noise
+    return sine_waves, uv, noise
+
+class SourceHnNSF:
+  def __init__(self, sampling_rate, harmonic_num=0, sine_amp=0.1, add_noise_std=0.003, voiced_threshold=0):
+    self.sine_amp, self.noise_std = sine_amp, add_noise_std
+    self.l_sin_gen = SineGen(sampling_rate, harmonic_num, sine_amp, add_noise_std, voiced_threshold)
+    self.l_linear = nn.Linear(harmonic_num + 1, 1)
+  def forward(self, x, upp=None):
+    sine_waves, uv, _ = self.l_sin_gen.forward(x, upp)
+    sine_merge = self.l_linear(sine_waves.cast(self.l_linear.weight.dtype)).tanh()
+    noise = randn_like(uv) * self.sine_amp / 3
+    return sine_merge, noise, uv
+
+# most of the hifigan in standard vits is reused here, but need to upsample and construct harmonic source from f0
+class Generator:
+  def __init__(self, sampling_rate, inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels):
+    self.sampling_rate, self.inter_channels, self.resblock, self.resblock_kernel_sizes, self.resblock_dilation_sizes, self.upsample_rates, self.upsample_initial_channel, self.upsample_kernel_sizes, self.gin_channels = sampling_rate, inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels
+    self.num_kernels, self.num_upsamples = len(resblock_kernel_sizes), len(upsample_rates)
+    self.conv_pre = nn.Conv1d(inter_channels, upsample_initial_channel, 7, 1, padding=3)
+    self.f0_upsamp = Upsample(scale_factor=np.prod(upsample_rates))
+    self.m_source = SourceHnNSF(sampling_rate, harmonic_num=8)
+    resblock = ResBlock1 if resblock == '1' else ResBlock2
+    self.ups, self.noise_convs, self.resblocks = [], [], []
+    for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
+      c_cur = upsample_initial_channel//(2**(i+1))
+      self.ups.append(nn.ConvTranspose1d(upsample_initial_channel//(2**i), c_cur, k, u, padding=(k-u)//2))
+      stride_f0 = int(np.prod(upsample_rates[i + 1:]))
+      self.noise_convs.append(nn.Conv1d(1, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=(stride_f0+1) // 2) if (i + 1 < len(upsample_rates)) else nn.Conv1d(1, c_cur, kernel_size=1))
+    for i in range(len(self.ups)):
+      ch = upsample_initial_channel // (2 ** (i + 1))
+      for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
+        self.resblocks.append(resblock(ch, k, d))
+    self.conv_post = nn.Conv1d(ch, 1, 7, 1, padding=3)
+    if gin_channels != 0: self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
+    self.upp = np.prod(upsample_rates)
+  def forward(self, x, f0, g=None):
+    f0 = self.f0_upsamp.forward(f0[:, None]).transpose(1, 2)  # bs,n,t
+    har_source, _, _ = self.m_source.forward(f0, self.upp)
+    har_source = har_source.transpose(1, 2)
+    x = self.conv_pre(x)
+    if g is not None:  x = x + self.cond(g)
+    for i in range(self.num_upsamples):
+      x, xs = self.ups[i](x.leakyrelu(LRELU_SLOPE)), None
+      x_source = self.noise_convs[i](har_source)
+      x = x + x_source
+      for j in range(self.num_kernels):
+        if xs is None: xs = self.resblocks[i * self.num_kernels + j].forward(x)
+        else: xs += self.resblocks[i * self.num_kernels + j].forward(x)
+      x = xs / self.num_kernels
+    return self.conv_post(x.leakyrelu()).tanh()
+
+# **** helpers ****
+
+def randn_like(x:Tensor) -> Tensor: return Tensor.randn(*x.shape, dtype=x.dtype).to(device=x.device)
+
+def tilde(x: Tensor) -> Tensor:
+  if x.dtype == dtypes.bool: return (1 - x).cast(dtypes.bool)
+  return (x + 1) * -1  # this seems to be what the ~ operator does in pytorch for non bool
+
+def lengths_to_padding_mask(lens:Tensor) -> Tensor:
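+  # boolean mask that is True at positions at or beyond each sequence's length (i.e. padding positions)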
+  bsz, max_lens = lens.shape[0], lens.max().numpy().item()
+  mask = Tensor.arange(max_lens).to(lens.device).reshape(1, max_lens)
+  mask = mask.expand(bsz, -1) >= lens.reshape(bsz, 1).expand(-1, max_lens)
+  return mask.cast(dtypes.bool)
+
+def repeat_expand_2d_left(content, target_len): # content : [h, t]
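+  # stretch content from src_len to target_len frames by repeating the nearest source frame to the left (no interpolation)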
+  src_len = content.shape[-1]
+  temp = np.arange(src_len+1) * target_len / src_len
+  current_pos, cols = 0, []
+  for i in range(target_len):
+    if i >= temp[current_pos+1]:
+      current_pos += 1
+    cols.append(content[:, current_pos])
+  return Tensor.stack(*cols).transpose(0, 1)
+
+def load_fairseq_cfg(checkpoint_path):
+  assert Path(checkpoint_path).is_file()
+  state = torch_load(checkpoint_path)
+  cfg = state["cfg"] if ("cfg" in state and state["cfg"] is not None) else None
+  if cfg is None: raise RuntimeError(f"No cfg exist in state keys = {state.keys()}")
+  return HParams(**cfg)
+
+def load_checkpoint_enc(checkpoint_path, model: ContentVec, optimizer=None, skip_list=[]):
+  assert Path(checkpoint_path).is_file()
+  start_time = time.time()
+  checkpoint_dict = torch_load(checkpoint_path)
+  saved_state_dict = checkpoint_dict['model']
+  weight_g, weight_v, parent = None, None, None
+  for key, v in saved_state_dict.items():
+    if any(layer in key for layer in skip_list): continue
+    try:
+      obj, skip = model, False
+      for k in key.split('.'):
+        if k.isnumeric(): obj = obj[int(k)]
+        elif isinstance(obj, dict): obj = obj[k]
+        else:
+          if k in ["weight_g", "weight_v"]:
+            parent, skip = obj, True
+            if k == "weight_g": weight_g = v
+            else: weight_v = v
+          if not skip:
+            parent = obj
+            obj = getattr(obj, k)
+      if weight_g and weight_v:
+        setattr(obj, "weight_g", weight_g.numpy())
+        setattr(obj, "weight_v", weight_v.numpy())
+        obj, v = getattr(parent, "weight"), weight_norm(weight_v, weight_g, 0)
+        weight_g, weight_v, parent, skip = None, None, None, False
+      if not skip and obj.shape == v.shape:
+        if "feature_extractor" in key and (isinstance(parent, nn.GroupNorm) or isinstance(parent, nn.LayerNorm)):  # cast
+          obj.assign(v.to(obj.device).float())
+        else:
+          obj.assign(v.to(obj.device))
+      elif not skip: logging.error(f"MISMATCH SHAPE IN {key}, {obj.shape} {v.shape}")
+    except Exception as e: raise e
+  logging.info(f"Loaded checkpoint '{checkpoint_path}' in {time.time() - start_time:.4f}s")
+  return model, optimizer
+
+def pad_array(arr, target_length):
+  current_length = arr.shape[0]
+  if current_length >= target_length: return arr
+  pad_width = target_length - current_length
+  pad_left = pad_width // 2
+  pad_right = pad_width - pad_left
+  padded_arr = np.pad(arr, (pad_left, pad_right), 'constant', constant_values=(0, 0))
+  return padded_arr
+
+def split_list_by_n(list_collection, n, pre=0):
+  for i in range(0, len(list_collection), n):
+    yield list_collection[i-pre if i-pre>=0 else i: i + n]
+
+def get_sid(spk2id:HParams, speaker:str) -> Tensor:
+  speaker_id = spk2id[speaker]
+  if not speaker_id and type(speaker) is int:
+    if len(spk2id.__dict__) >= speaker: speaker_id = speaker
+  if speaker_id is None: raise RuntimeError(f"speaker={speaker} not in the speaker list")
+  return Tensor([int(speaker_id)], dtype=dtypes.int64).unsqueeze(0)
+
+def get_encoder(ssl_dim) -> Type[SpeechEncoder]:
+  if ssl_dim == 256: return ContentVec256L9
+  if ssl_dim == 768: return ContentVec768L12
+
+#########################################################################################
+# CODE: https://github.com/svc-develop-team/so-vits-svc
+#########################################################################################
+# CONTENTVEC:
+#   CODE: https://github.com/auspicious3000/contentvec
+#   PAPER: https://arxiv.org/abs/2204.09224
+#########################################################################################
+# INSTALLATION: dependencies are for preprocessing and loading/saving audio.
+# pip3 install soundfile librosa praat-parselmouth
+#########################################################################################
+# EXAMPLE USAGE:
+# python3 examples/so_vits_svc.py --model tf2spy --file ~/recording.wav
+#########################################################################################
+# DEMO USAGE (uses audio sample from LJ-Speech):
+# python3 examples/so_vits_svc.py --model saul_goodman
+#########################################################################################
+SO_VITS_SVC_PATH = Path(__file__).parents[1] / "weights/So-VITS-SVC"
+VITS_MODELS = { # config_path, weights_path, config_url, weights_url
+  "saul_goodman" : (SO_VITS_SVC_PATH / "config_saul_gman.json", SO_VITS_SVC_PATH / "pretrained_saul_gman.pth", "https://huggingface.co/Amo/so-vits-svc-4.0_GA/resolve/main/ModelsFolder/Saul_Goodman_80000/config.json", "https://huggingface.co/Amo/so-vits-svc-4.0_GA/resolve/main/ModelsFolder/Saul_Goodman_80000/G_80000.pth"),
+  "drake" : (SO_VITS_SVC_PATH / "config_drake.json", SO_VITS_SVC_PATH / "pretrained_drake.pth", "https://huggingface.co/jaspa/so-vits-svc/resolve/main/aubrey/config_aubrey.json", "https://huggingface.co/jaspa/so-vits-svc/resolve/main/aubrey/pretrained_aubrey.pth"),
+  "cartman" : (SO_VITS_SVC_PATH / "config_cartman.json", SO_VITS_SVC_PATH / "pretrained_cartman.pth", "https://huggingface.co/marcoc2/so-vits-svc-4.0-models/resolve/main/EricCartman/config.json", "https://huggingface.co/marcoc2/so-vits-svc-4.0-models/resolve/main/EricCartman/G_10200.pth"),
+  "tf2spy" : (SO_VITS_SVC_PATH / "config_tf2spy.json", SO_VITS_SVC_PATH / "pretrained_tf2spy.pth", "https://huggingface.co/Amo/so-vits-svc-4.0_GA/resolve/main/ModelsFolder/TF2_spy_60k/config.json", "https://huggingface.co/Amo/so-vits-svc-4.0_GA/resolve/main/ModelsFolder/TF2_spy_60k/G_60000.pth"),
+  "tf2heavy" : (SO_VITS_SVC_PATH / "config_tf2heavy.json", SO_VITS_SVC_PATH / "pretrained_tf2heavy.pth", "https://huggingface.co/Amo/so-vits-svc-4.0_GA/resolve/main/ModelsFolder/TF2_heavy_100k/config.json", "https://huggingface.co/Amo/so-vits-svc-4.0_GA/resolve/main/ModelsFolder/TF2_heavy_100k/G_100000.pth"),
+  "lady_gaga" : (SO_VITS_SVC_PATH / "config_gaga.json", SO_VITS_SVC_PATH / "pretrained_gaga.pth", "https://huggingface.co/marcoc2/so-vits-svc-4.0-models/resolve/main/LadyGaga/config.json", "https://huggingface.co/marcoc2/so-vits-svc-4.0-models/resolve/main/LadyGaga/G_14400.pth")
+}
+ENCODER_MODELS = { # weights_path, weights_url
+  "contentvec": (SO_VITS_SVC_PATH / "contentvec_checkpoint.pt", "https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/hubert_base.pt")
+}
+ENCODER_MODEL = "contentvec"
+DEMO_PATH, DEMO_URL = Path(__file__).parents[1] / "temp/LJ037-0171.wav", "https://keithito.com/LJ-Speech-Dataset/LJ037-0171.wav"
+if __name__=="__main__":
+  logging.basicConfig(stream=sys.stdout, level=(logging.INFO if DEBUG < 1 else logging.DEBUG))
+  parser = argparse.ArgumentParser()
+  parser.add_argument("-m", "--model", default=None, help=f"Specify the model to use. All supported models: {VITS_MODELS.keys()}", required=True)
+  parser.add_argument("-f", "--file", default=DEMO_PATH, help=f"Specify the path of the input file")
+  parser.add_argument("--out_dir", default=str(Path(__file__).parents[1] / "temp"), help="Specify the output path.")
+  parser.add_argument("--out_path", default=None, help="Specify the full output path. Overrides the --out_dir and --name parameter.")
+  parser.add_argument("--base_name", default="test", help="Specify the base of the output file name. Default is 'test'.")
+  parser.add_argument("--speaker", default=None, help="If not specified, the first available speaker is chosen. Usually there is only one speaker per model.")
+  parser.add_argument("--noise_scale", default=0.4)
+  parser.add_argument("--tran", default=0.0, help="Pitch shift, supports positive and negative (semitone) values. Default 0.0")
+  parser.add_argument("--pad_seconds", default=0.5)
+  parser.add_argument("--lg_num", default=0.0)
+  parser.add_argument("--clip_seconds", default=0.0)
+  parser.add_argument("--slice_db", default=-40)
+  args = parser.parse_args()
+
+  vits_model = args.model
+  encoder_location, vits_location = ENCODER_MODELS[ENCODER_MODEL], VITS_MODELS[vits_model]
+
+  Tensor.no_grad, Tensor.training = True, False
+  # Get Synthesizer and ContentVec
+  net_g, hps = Synthesizer.load_from_pretrained(vits_location[0], vits_location[2], vits_location[1], vits_location[3])
+  Encoder = get_encoder(hps.model.ssl_dim)
+  encoder = Encoder.load_from_pretrained(encoder_location[0], encoder_location[1])
+
+  # model config args
+  target_sample, spk2id, hop_length = hps.data.sampling_rate, hps.spk, hps.data.hop_length
+  vol_embedding = hps.model.vol_embedding if hasattr(hps.model, "vol_embedding") and hps.model.vol_embedding is not None else False
+
+  # args
+  slice_db, clip_seconds, lg_num, pad_seconds, tran, noise_scale, audio_path = args.slice_db, args.clip_seconds, args.lg_num, args.pad_seconds, args.tran, args.noise_scale, args.file
+  speaker = args.speaker if args.speaker is not None else list(hps.spk.__dict__.keys())[0]
+
+  ### Loading audio and slicing ###
+  if audio_path == DEMO_PATH: download_if_not_present(DEMO_PATH, DEMO_URL)
+  assert Path(audio_path).is_file() and Path(audio_path).suffix == ".wav"
+  chunks = preprocess.cut(audio_path, db_thresh=slice_db)
+  audio_data, audio_sr = preprocess.chunks2audio(audio_path, chunks)
+
+  per_size = int(clip_seconds * audio_sr)
+  lg_size = int(lg_num * audio_sr)
+
+  ### Infer per slice ###
+  global_frame = 0
+  audio = []
+  for (slice_tag, data) in audio_data:
+    print(f"\n====segment start, {round(len(data) / audio_sr, 3)}s====")
+    length = int(np.ceil(len(data) / audio_sr * target_sample))
+
+    if slice_tag:
+      print("empty segment")
+      _audio = np.zeros(length)
+      audio.extend(list(pad_array(_audio, length)))
+      global_frame += length // hop_length
+      continue
+
+    datas = [data] if per_size == 0 else split_list_by_n(data, per_size, lg_size)
+
+    for k, dat in enumerate(datas):
+      per_length = int(np.ceil(len(dat) / audio_sr * target_sample)) if clip_seconds!=0 else length
+      pad_len = int(audio_sr * pad_seconds)
+      dat = np.concatenate([np.zeros([pad_len]), dat, np.zeros([pad_len])])
+      raw_path = io.BytesIO()
+      soundfile.write(raw_path, dat, audio_sr, format="wav")
+      raw_path.seek(0)
+
+      ### Infer START ###
+      wav, sr = preprocess.load_audiofile(raw_path)
+      wav = preprocess.sinc_interp_resample(wav, sr, target_sample)[0]
+      wav16k, f0, uv = preprocess.get_unit_f0(wav, tran, hop_length, target_sample)
+      sid = get_sid(spk2id, speaker)
+      n_frames = f0.shape[1]
+
+      # ContentVec infer
+      start = time.time()
+      c = encoder.encode(wav16k)
+      c = repeat_expand_2d_left(c.squeeze(0).realize(), f0.shape[1])  # interpolate speech encoding to match f0
+      c = c.unsqueeze(0).realize()
+      enc_time = time.time() - start
+
+      # VITS infer
+      vits_start = time.time()
+      out_audio, f0 = net_g.infer(c, f0=f0, uv=uv, g=sid, noise_scale=noise_scale, vol=None)
+      out_audio = out_audio[0,0].float().realize()
+      vits_time = time.time() - vits_start
+
+      infer_time = time.time() - start
+      logging.info("total infer time:{:.2f}s, speech_enc time:{:.2f}s, vits time:{:.2f}s".format(infer_time, enc_time, vits_time))
+      ### Infer END ###
+
+      out_sr, out_frame = out_audio.shape[-1], n_frames
+      global_frame += out_frame
+      _audio = out_audio.numpy()
+      pad_len = int(target_sample * pad_seconds)
+      _audio = _audio[pad_len:-pad_len]
+      _audio = pad_array(_audio, per_length)
+      audio.extend(list(_audio))
+
+  audio = np.array(audio)
+  out_path = Path(args.out_path or Path(args.out_dir) / f"{args.model}_spk_{speaker}_{args.base_name}.wav")
+  out_path.parent.mkdir(parents=True, exist_ok=True)
+  soundfile.write(out_path, audio, target_sample, format="wav")
+  logging.info(f"Saved audio output to {out_path}")

+ 204 - 0
tinychat/examples/sovits_helpers/preprocess.py

@@ -0,0 +1,204 @@
+import math
+from typing import Optional, Tuple
+from tinygrad import Tensor, dtypes
+import librosa
+import soundfile
+import numpy as np
+import parselmouth
+
+class PMF0Predictor:  # from https://github.com/svc-develop-team/so-vits-svc/
+  def __init__(self,hop_length=512,f0_min=50,f0_max=1100,sampling_rate=44100):
+    self.hop_length, self.f0_min, self.f0_max, self.sampling_rate, self.name = hop_length, f0_min, f0_max, sampling_rate, "pm"
+  def interpolate_f0(self,f0):
+    vuv_vector = np.zeros_like(f0, dtype=np.float32)
+    vuv_vector[f0 > 0.0] = 1.0
+    vuv_vector[f0 <= 0.0] = 0.0
+    nzindex = np.nonzero(f0)[0]
+    data = f0[nzindex]
+    nzindex = nzindex.astype(np.float32)
+    time_org = self.hop_length / self.sampling_rate * nzindex
+    time_frame = np.arange(f0.shape[0]) * self.hop_length / self.sampling_rate
+    if data.shape[0] <= 0: return np.zeros(f0.shape[0], dtype=np.float32),vuv_vector
+    if data.shape[0] == 1: return np.ones(f0.shape[0], dtype=np.float32) * f0[0],vuv_vector
+    f0 = np.interp(time_frame, time_org, data, left=data[0], right=data[-1])
+    return f0,vuv_vector
+  def compute_f0(self,wav,p_len=None):
+    x = wav
+    if p_len is None: p_len = x.shape[0]//self.hop_length
+    else: assert abs(p_len-x.shape[0]//self.hop_length) < 4, "pad length error"
+    time_step = self.hop_length / self.sampling_rate * 1000
+    f0 = parselmouth.Sound(x, self.sampling_rate) \
+                    .to_pitch_ac(time_step=time_step / 1000, voicing_threshold=0.6,pitch_floor=self.f0_min, pitch_ceiling=self.f0_max) \
+                    .selected_array['frequency']
+    pad_size=(p_len - len(f0) + 1) // 2
+    if(pad_size>0 or p_len - len(f0) - pad_size>0):
+      f0 = np.pad(f0,[[pad_size,p_len - len(f0) - pad_size]], mode='constant')
+    f0,uv = self.interpolate_f0(f0)
+    return f0
+  def compute_f0_uv(self,wav,p_len=None):
+    x = wav
+    if p_len is None: p_len = x.shape[0]//self.hop_length
+    else: assert abs(p_len-x.shape[0]//self.hop_length) < 4, "pad length error"
+    time_step = self.hop_length / self.sampling_rate * 1000
+    f0 = parselmouth.Sound(x, self.sampling_rate).to_pitch_ac(
+      time_step=time_step / 1000, voicing_threshold=0.6,
+      pitch_floor=self.f0_min, pitch_ceiling=self.f0_max).selected_array['frequency']
+    pad_size=(p_len - len(f0) + 1) // 2
+    if(pad_size>0 or p_len - len(f0) - pad_size>0):
+      f0 = np.pad(f0,[[pad_size,p_len - len(f0) - pad_size]], mode='constant')
+    f0,uv = self.interpolate_f0(f0)
+    return f0,uv
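+
+# A minimal usage sketch (hypothetical input, 1 second of silence): f0 is the interpolated pitch track in Hz
+# and uv is the voiced/unvoiced mask, one value per hop_length-sized frame.
+#   f0, uv = PMF0Predictor(hop_length=512, sampling_rate=44100).compute_f0_uv(np.zeros(44100, dtype=np.float32))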
+
+class Slicer:  # from https://github.com/svc-develop-team/so-vits-svc/
+  def __init__(self, sr: int, threshold: float = -40., min_length: int = 5000, min_interval: int = 300, hop_size: int = 20, max_sil_kept: int = 5000):
+    if not min_length >= min_interval >= hop_size:
+      raise ValueError('The following condition must be satisfied: min_length >= min_interval >= hop_size')
+    if not max_sil_kept >= hop_size:
+      raise ValueError('The following condition must be satisfied: max_sil_kept >= hop_size')
+    min_interval = sr * min_interval / 1000
+    self.threshold = 10 ** (threshold / 20.)
+    self.hop_size = round(sr * hop_size / 1000)
+    self.win_size = min(round(min_interval), 4 * self.hop_size)
+    self.min_length = round(sr * min_length / 1000 / self.hop_size)
+    self.min_interval = round(min_interval / self.hop_size)
+    self.max_sil_kept = round(sr * max_sil_kept / 1000 / self.hop_size)
+  def _apply_slice(self, waveform, begin, end):
+    if len(waveform.shape) > 1: return waveform[:, begin * self.hop_size: min(waveform.shape[1], end * self.hop_size)]
+    else: return waveform[begin * self.hop_size: min(waveform.shape[0], end * self.hop_size)]
+  def slice(self, waveform):
+    samples = librosa.to_mono(waveform) if len(waveform.shape) > 1 else waveform
+    if samples.shape[0] <= self.min_length: return {"0": {"slice": False, "split_time": f"0,{len(waveform)}"}}
+    rms_list = librosa.feature.rms(y=samples, frame_length=self.win_size, hop_length=self.hop_size).squeeze(0)
+    sil_tags, silence_start, clip_start = [], None, 0
+    for i, rms in enumerate(rms_list):
+      if rms < self.threshold:  # Keep looping while frame is silent.
+        if silence_start is None:  # Record start of silent frames.
+          silence_start = i
+        continue
+      if silence_start is None: continue  # Keep looping while frame is not silent and silence start has not been recorded.
+      # Clear recorded silence start if interval is not enough or clip is too short
+      is_leading_silence = silence_start == 0 and i > self.max_sil_kept
+      need_slice_middle = i - silence_start >= self.min_interval and i - clip_start >= self.min_length
+      if not is_leading_silence and not need_slice_middle:
+        silence_start = None
+        continue
+      if i - silence_start <= self.max_sil_kept:  # Need slicing. Record the range of silent frames to be removed.
+        pos = rms_list[silence_start: i + 1].argmin() + silence_start
+        sil_tags.append((0, pos) if silence_start == 0 else (pos, pos))
+        clip_start = pos
+      elif i - silence_start <= self.max_sil_kept * 2:
+        pos = rms_list[i - self.max_sil_kept: silence_start + self.max_sil_kept + 1].argmin()
+        pos += i - self.max_sil_kept
+        pos_l = rms_list[silence_start: silence_start + self.max_sil_kept + 1].argmin() + silence_start
+        pos_r = rms_list[i - self.max_sil_kept: i + 1].argmin() + i - self.max_sil_kept
+        if silence_start == 0:
+          sil_tags.append((0, pos_r))
+          clip_start = pos_r
+        else:
+          sil_tags.append((min(pos_l, pos), max(pos_r, pos)))
+          clip_start = max(pos_r, pos)
+      else:
+        pos_l = rms_list[silence_start: silence_start + self.max_sil_kept + 1].argmin() + silence_start
+        pos_r = rms_list[i - self.max_sil_kept: i + 1].argmin() + i - self.max_sil_kept
+        sil_tags.append((0, pos_r) if silence_start == 0 else (pos_l, pos_r))
+        clip_start = pos_r
+      silence_start = None
+    total_frames = rms_list.shape[0]
+    if silence_start is not None and total_frames - silence_start >= self.min_interval:  # Deal with trailing silence.
+      silence_end = min(total_frames, silence_start + self.max_sil_kept)
+      pos = rms_list[silence_start: silence_end + 1].argmin() + silence_start
+      sil_tags.append((pos, total_frames + 1))
+    if len(sil_tags) == 0: return {"0": {"slice": False, "split_time": f"0,{len(waveform)}"}}  # Apply and return slices.
+    chunks = []
+    if sil_tags[0][0]:
+      chunks.append({"slice": False, "split_time": f"0,{min(waveform.shape[0], sil_tags[0][0] * self.hop_size)}"})
+    for i in range(0, len(sil_tags)):
+      if i: chunks.append({"slice": False, "split_time": f"{sil_tags[i - 1][1] * self.hop_size},{min(waveform.shape[0], sil_tags[i][0] * self.hop_size)}"})
+      chunks.append({"slice": True, "split_time": f"{sil_tags[i][0] * self.hop_size},{min(waveform.shape[0], sil_tags[i][1] * self.hop_size)}"})
+    if sil_tags[-1][1] * self.hop_size < len(waveform):
+      chunks.append({"slice": False, "split_time": f"{sil_tags[-1][1] * self.hop_size},{len(waveform)}"})
+    chunk_dict = {}
+    for i in range(len(chunks)): chunk_dict[str(i)] = chunks[i]
+    return chunk_dict
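+
+# A minimal sketch of how Slicer is used by cut()/chunks2audio() below; the wav path is a hypothetical example.
+#   audio, sr = librosa.load("vocals.wav", sr=None)
+#   chunks = Slicer(sr=sr, threshold=-40., min_length=5000).slice(audio)
+#   for k, v in chunks.items(): print(k, v["slice"], v["split_time"])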
+
+# sinc_interp_hann audio resampling
+class Resample:
+  def __init__(self, orig_freq:int=16000, new_freq:int=16000, lowpass_filter_width:int=6, rolloff:float=0.99, beta:Optional[float]=None, dtype:Optional[dtypes]=None):
+    self.orig_freq, self.new_freq, self.lowpass_filter_width, self.rolloff, self.beta = orig_freq, new_freq, lowpass_filter_width, rolloff, beta
+    self.gcd = math.gcd(int(self.orig_freq), int(self.new_freq))
+    self.kernel, self.width = self._get_sinc_resample_kernel(dtype) if self.orig_freq != self.new_freq else (None, None)
+  def __call__(self, waveform:Tensor) -> Tensor:
+    if self.orig_freq == self.new_freq: return waveform
+    return self._apply_sinc_resample_kernel(waveform)
+  def _apply_sinc_resample_kernel(self, waveform:Tensor):
+    if not waveform.is_floating_point(): raise TypeError(f"Waveform tensor expected to be of type float, but received {waveform.dtype}.")
+    orig_freq, new_freq = (int(self.orig_freq) // self.gcd), (int(self.new_freq) // self.gcd)
+    shape = waveform.shape
+    waveform = waveform.reshape(-1, shape[-1])  # pack batch
+    num_wavs, length = waveform.shape
+    target_length = int(math.ceil(new_freq * length / orig_freq))
+    waveform = waveform.pad2d((self.width, self.width + orig_freq))
+    resampled = waveform[:, None].conv2d(self.kernel, stride=orig_freq)
+    resampled = resampled.transpose(1, 2).reshape(num_wavs, -1)
+    resampled = resampled[..., :target_length]
+    resampled = resampled.reshape(shape[:-1] + resampled.shape[-1:])  # unpack batch
+    return resampled
+  def _get_sinc_resample_kernel(self, dtype=None):
+    orig_freq, new_freq = (int(self.orig_freq) // self.gcd), (int(self.new_freq) // self.gcd)
+    if self.lowpass_filter_width <= 0: raise ValueError("Low pass filter width should be positive.")
+    base_freq = min(orig_freq, new_freq)
+    base_freq *= self.rolloff
+    width = math.ceil(self.lowpass_filter_width * orig_freq / base_freq)
+    idx = Tensor.arange(-width, width + orig_freq, dtype=(dtype if dtype is not None else dtypes.float32))[None, None] / orig_freq
+    t = Tensor.arange(0, -new_freq, -1, dtype=dtype)[:, None, None] / new_freq + idx
+    t *= base_freq
+    t = t.clip(-self.lowpass_filter_width, self.lowpass_filter_width)
+    window = (t * math.pi / self.lowpass_filter_width / 2).cos() ** 2
+    t *= math.pi
+    scale = base_freq / orig_freq
+    kernels = Tensor.where(t == 0, Tensor(1.0, dtype=t.dtype).to(t.device), t.sin() / t)
+    kernels *= window * scale
+    if dtype is None: kernels = kernels.cast(dtype=dtypes.float32)
+    return kernels, width
+
+def sinc_interp_resample(x:Tensor, orig_freq:int=16000, new_freq:int=16000, lowpass_filter_width:int=6, rolloff:float=0.99, beta:Optional[float]=None):
+  resamp = Resample(orig_freq, new_freq, lowpass_filter_width, rolloff, beta, x.dtype)
+  return resamp(x)
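+
+# A minimal usage sketch (shapes assumed): downsample a 1-second 44.1 kHz mono waveform to 16 kHz,
+# mirroring how get_unit_f0() below prepares wav16k for the speech encoder.
+#   wav16k = sinc_interp_resample(Tensor.randn(1, 44100), orig_freq=44100, new_freq=16000)
+#   print(wav16k.shape)  # (1, 16000)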
+
+def cut(audio_path, db_thresh=-30, min_len=5000):
+  audio, sr = librosa.load(audio_path, sr=None)
+  slicer = Slicer(sr=sr, threshold=db_thresh, min_length=min_len)
+  chunks = slicer.slice(audio)
+  return chunks
+
+def chunks2audio(audio_path, chunks):
+  chunks = dict(chunks)
+  audio, sr = load_audiofile(audio_path)
+  if len(audio.shape) == 2 and audio.shape[1] >= 2:
+    audio = audio.mean(0).unsqueeze(0)
+  audio = audio.numpy()[0]
+  result = []
+  for k, v in chunks.items():
+    tag = v["split_time"].split(",")
+    if tag[0] != tag[1]:
+      result.append((v["slice"], audio[int(tag[0]):int(tag[1])]))
+  return result, sr
+
+def load_audiofile(filepath:str, frame_offset:int=0, num_frames:int=-1, channels_first:bool=True):
+  with soundfile.SoundFile(filepath, "r") as file_:
+    frames = file_._prepare_read(frame_offset, None, num_frames)
+    waveform = file_.read(frames, "float32", always_2d=True)
+    sample_rate = file_.samplerate
+  waveform = Tensor(waveform)
+  if channels_first: waveform = waveform.transpose(0, 1)
+  return waveform, sample_rate
+
+def get_unit_f0(wav:Tensor, tran, hop_length, target_sample, f0_filter=False) -> Tuple[Tensor,Tensor,Tensor]:
+  f0_predictor = PMF0Predictor(hop_length, sampling_rate=target_sample)
+  f0, uv = f0_predictor.compute_f0_uv(wav.numpy())
+  if f0_filter and sum(f0) == 0: raise RuntimeError("No voice detected")
+  f0 = Tensor(f0.astype(np.float32)).float()
+  f0 = (f0 * 2 ** (tran / 12)).unsqueeze(0)
+  uv = Tensor(uv.astype(np.float32)).float().unsqueeze(0)
+  wav16k = sinc_interp_resample(wav[None,:], target_sample, 16000)[0]
+  return wav16k.realize(), f0.realize(), uv.realize()

+ 294 - 0
tinychat/examples/stable_diffusion.py

@@ -0,0 +1,294 @@
+# https://arxiv.org/pdf/2112.10752.pdf
+# https://github.com/ekagra-ranjan/huggingface-blog/blob/main/stable_diffusion.md
+import tempfile
+from pathlib import Path
+import argparse
+from collections import namedtuple
+from typing import Dict, Any
+
+from PIL import Image
+import numpy as np
+from tinygrad import Device, GlobalCounters, dtypes, Tensor, TinyJit
+from tinygrad.helpers import Timing, Context, getenv, fetch, colored, tqdm
+from tinygrad.nn import Conv2d, GroupNorm
+from tinygrad.nn.state import torch_load, load_state_dict, get_state_dict
+from extra.models.clip import Closed, Tokenizer
+from extra.models.unet import UNetModel
+
+class AttnBlock:
+  def __init__(self, in_channels):
+    self.norm = GroupNorm(32, in_channels)
+    self.q = Conv2d(in_channels, in_channels, 1)
+    self.k = Conv2d(in_channels, in_channels, 1)
+    self.v = Conv2d(in_channels, in_channels, 1)
+    self.proj_out = Conv2d(in_channels, in_channels, 1)
+
+  # copied from AttnBlock in ldm repo
+  def __call__(self, x):
+    h_ = self.norm(x)
+    q,k,v = self.q(h_), self.k(h_), self.v(h_)
+
+    # compute attention
+    b,c,h,w = q.shape
+    q,k,v = [x.reshape(b,c,h*w).transpose(1,2) for x in (q,k,v)]
+    h_ = Tensor.scaled_dot_product_attention(q,k,v).transpose(1,2).reshape(b,c,h,w)
+    return x + self.proj_out(h_)
+
+class ResnetBlock:
+  def __init__(self, in_channels, out_channels=None):
+    self.norm1 = GroupNorm(32, in_channels)
+    self.conv1 = Conv2d(in_channels, out_channels, 3, padding=1)
+    self.norm2 = GroupNorm(32, out_channels)
+    self.conv2 = Conv2d(out_channels, out_channels, 3, padding=1)
+    self.nin_shortcut = Conv2d(in_channels, out_channels, 1) if in_channels != out_channels else lambda x: x
+
+  def __call__(self, x):
+    h = self.conv1(self.norm1(x).swish())
+    h = self.conv2(self.norm2(h).swish())
+    return self.nin_shortcut(x) + h
+
+class Mid:
+  def __init__(self, block_in):
+    self.block_1 = ResnetBlock(block_in, block_in)
+    self.attn_1 = AttnBlock(block_in)
+    self.block_2 = ResnetBlock(block_in, block_in)
+
+  def __call__(self, x):
+    return x.sequential([self.block_1, self.attn_1, self.block_2])
+
+class Decoder:
+  def __init__(self):
+    sz = [(128, 256), (256, 512), (512, 512), (512, 512)]
+    self.conv_in = Conv2d(4,512,3, padding=1)
+    self.mid = Mid(512)
+
+    arr = []
+    for i,s in enumerate(sz):
+      arr.append({"block":
+        [ResnetBlock(s[1], s[0]),
+         ResnetBlock(s[0], s[0]),
+         ResnetBlock(s[0], s[0])]})
+      if i != 0: arr[-1]['upsample'] = {"conv": Conv2d(s[0], s[0], 3, padding=1)}
+    self.up = arr
+
+    self.norm_out = GroupNorm(32, 128)
+    self.conv_out = Conv2d(128, 3, 3, padding=1)
+
+  def __call__(self, x):
+    x = self.conv_in(x)
+    x = self.mid(x)
+
+    for l in self.up[::-1]:
+      print("decode", x.shape)
+      for b in l['block']: x = b(x)
+      if 'upsample' in l:
+        # https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html ?
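+        # the reshape/expand below is a nearest-neighbor 2x upsample: each pixel is repeated along height and width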
+        bs,c,py,px = x.shape
+        x = x.reshape(bs, c, py, 1, px, 1).expand(bs, c, py, 2, px, 2).reshape(bs, c, py*2, px*2)
+        x = l['upsample']['conv'](x)
+      x.realize()
+
+    return self.conv_out(self.norm_out(x).swish())
+
+class Encoder:
+  def __init__(self):
+    sz = [(128, 128), (128, 256), (256, 512), (512, 512)]
+    self.conv_in = Conv2d(3,128,3, padding=1)
+
+    arr = []
+    for i,s in enumerate(sz):
+      arr.append({"block":
+        [ResnetBlock(s[0], s[1]),
+         ResnetBlock(s[1], s[1])]})
+      if i != 3: arr[-1]['downsample'] = {"conv": Conv2d(s[1], s[1], 3, stride=2, padding=(0,1,0,1))}
+    self.down = arr
+
+    self.mid = Mid(512)
+    self.norm_out = GroupNorm(32, 512)
+    self.conv_out = Conv2d(512, 8, 3, padding=1)
+
+  def __call__(self, x):
+    x = self.conv_in(x)
+
+    for l in self.down:
+      print("encode", x.shape)
+      for b in l['block']: x = b(x)
+      if 'downsample' in l: x = l['downsample']['conv'](x)
+
+    x = self.mid(x)
+    return self.conv_out(self.norm_out(x).swish())
+
+class AutoencoderKL:
+  def __init__(self):
+    self.encoder = Encoder()
+    self.decoder = Decoder()
+    self.quant_conv = Conv2d(8, 8, 1)
+    self.post_quant_conv = Conv2d(4, 4, 1)
+
+  def __call__(self, x):
+    latent = self.encoder(x)
+    latent = self.quant_conv(latent)
+    latent = latent[:, 0:4]  # only the means
+    print("latent", latent.shape)
+    latent = self.post_quant_conv(latent)
+    return self.decoder(latent)
+
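+# DDPM noise schedule: betas are spaced linearly in sqrt space between beta_start and beta_end,
+# and alphas_cumprod[t] = prod_{s<=t} (1 - beta_s) is the cumulative signal-retention factor
+# consumed by the DDIM step in StableDiffusion.get_x_prev_and_pred_x0 below.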
+def get_alphas_cumprod(beta_start=0.00085, beta_end=0.0120, n_training_steps=1000):
+  betas = np.linspace(beta_start ** 0.5, beta_end ** 0.5, n_training_steps, dtype=np.float32) ** 2
+  alphas = 1.0 - betas
+  alphas_cumprod = np.cumprod(alphas, axis=0)
+  return Tensor(alphas_cumprod)
+
+unet_params: Dict[str,Any] = {
+  "adm_in_ch": None,
+  "in_ch": 4,
+  "out_ch": 4,
+  "model_ch": 320,
+  "attention_resolutions": [4, 2, 1],
+  "num_res_blocks": 2,
+  "channel_mult": [1, 2, 4, 4],
+  "n_heads": 8,
+  "transformer_depth": [1, 1, 1, 1],
+  "ctx_dim": 768,
+  "use_linear": False,
+}
+
+class StableDiffusion:
+  def __init__(self):
+    self.alphas_cumprod = get_alphas_cumprod()
+    self.model = namedtuple("DiffusionModel", ["diffusion_model"])(diffusion_model = UNetModel(**unet_params))
+    self.first_stage_model = AutoencoderKL()
+    self.cond_stage_model = namedtuple("CondStageModel", ["transformer"])(transformer = namedtuple("Transformer", ["text_model"])(text_model = Closed.ClipTextTransformer()))
+
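+  # Deterministic DDIM update (eta = 0): pred_x0 = (x_t - sqrt(1 - a_t) * e_t) / sqrt(a_t),
+  # then x_prev = sqrt(a_prev) * pred_x0 + sqrt(1 - a_prev) * e_t; sigma_t = 0, so no fresh noise is added.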
+  def get_x_prev_and_pred_x0(self, x, e_t, a_t, a_prev):
+    temperature = 1
+    sigma_t = 0
+    sqrt_one_minus_at = (1-a_t).sqrt()
+    #print(a_t, a_prev, sigma_t, sqrt_one_minus_at)
+
+    pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+
+    # direction pointing to x_t
+    dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
+
+    x_prev = a_prev.sqrt() * pred_x0 + dir_xt
+    return x_prev, pred_x0
+
+  def get_model_output(self, unconditional_context, context, latent, timestep, unconditional_guidance_scale):
+    # put into diffuser
+    latents = self.model.diffusion_model(latent.expand(2, *latent.shape[1:]), timestep, unconditional_context.cat(context, dim=0))
+    unconditional_latent, latent = latents[0:1], latents[1:2]
+
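+    # classifier-free guidance: move the conditional prediction away from the unconditional one by the guidance scale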
+    e_t = unconditional_latent + unconditional_guidance_scale * (latent - unconditional_latent)
+    return e_t
+
+  def decode(self, x):
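+    # 0.18215 is the latent scale factor of the SD v1 autoencoder (the scale_factor in the original config); divide to undo it before decoding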
+    x = self.first_stage_model.post_quant_conv(1/0.18215 * x)
+    x = self.first_stage_model.decoder(x)
+
+    # make image correct size and scale
+    x = (x + 1.0) / 2.0
+    x = x.reshape(3,512,512).permute(1,2,0).clip(0,1)*255
+    return x.cast(dtypes.uint8) if Device.DEFAULT != "WEBGPU" else x
+
+  def __call__(self, unconditional_context, context, latent, timestep, alphas, alphas_prev, guidance):
+    e_t = self.get_model_output(unconditional_context, context, latent, timestep, guidance)
+    x_prev, _ = self.get_x_prev_and_pred_x0(latent, e_t, alphas, alphas_prev)
+    #e_t_next = get_model_output(x_prev)
+    #e_t_prime = (e_t + e_t_next) / 2
+    #x_prev, pred_x0 = get_x_prev_and_pred_x0(latent, e_t_prime, index)
+    return x_prev.realize()
+
+# ** ldm.models.autoencoder.AutoencoderKL (done!)
+# 3x512x512 <--> 4x64x64 (16384)
+# decode torch.Size([1, 4, 64, 64]) torch.Size([1, 3, 512, 512])
+# section 4.3 of paper
+# first_stage_model.encoder, first_stage_model.decoder
+
+# ** ldm.modules.diffusionmodules.openaimodel.UNetModel
+# this is what runs each time to sample. is this the LDM?
+# input:  4x64x64
+# output: 4x64x64
+# model.diffusion_model
+# it has attention?
+
+# ** ldm.modules.encoders.modules.FrozenCLIPEmbedder
+# cond_stage_model.transformer.text_model
+
+if __name__ == "__main__":
+  default_prompt = "a horse sized cat eating a bagel"
+  parser = argparse.ArgumentParser(description='Run Stable Diffusion', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+  parser.add_argument('--steps', type=int, default=5, help="Number of steps in diffusion")
+  parser.add_argument('--prompt', type=str, default=default_prompt, help="Phrase to render")
+  parser.add_argument('--out', type=str, default=Path(tempfile.gettempdir()) / "rendered.png", help="Output filename")
+  parser.add_argument('--noshow', action='store_true', help="Don't show the image")
+  parser.add_argument('--fp16', action='store_true', help="Cast the weights to float16")
+  parser.add_argument('--timing', action='store_true', help="Print timing per step")
+  parser.add_argument('--seed', type=int, help="Set the random latent seed")
+  parser.add_argument('--guidance', type=float, default=7.5, help="Prompt strength")
+  args = parser.parse_args()
+
+  Tensor.no_grad = True
+  model = StableDiffusion()
+
+  # load in weights
+  load_state_dict(model, torch_load(fetch('https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt', 'sd-v1-4.ckpt'))['state_dict'], strict=False)
+
+  if args.fp16:
+    for k,v in get_state_dict(model).items():
+      if k.startswith("model"):
+        v.replace(v.cast(dtypes.float16).realize())
+
+  # run through CLIP to get context
+  tokenizer = Tokenizer.ClipTokenizer()
+  prompt = Tensor([tokenizer.encode(args.prompt)])
+  context = model.cond_stage_model.transformer.text_model(prompt).realize()
+  print("got CLIP context", context.shape)
+
+  prompt = Tensor([tokenizer.encode("")])
+  unconditional_context = model.cond_stage_model.transformer.text_model(prompt).realize()
+  print("got unconditional CLIP context", unconditional_context.shape)
+
+  # done with clip model
+  del model.cond_stage_model
+
+  timesteps = list(range(1, 1000, 1000//args.steps))
+  print(f"running for {timesteps} timesteps")
+  alphas = model.alphas_cumprod[Tensor(timesteps)]
+  alphas_prev = Tensor([1.0]).cat(alphas[:-1])
+
+  # start with random noise
+  if args.seed is not None: Tensor.manual_seed(args.seed)
+  latent = Tensor.randn(1,4,64,64)
+
+  @TinyJit
+  def run(model, *x): return model(*x).realize()
+
+  # this is diffusion
+  with Context(BEAM=getenv("LATEBEAM")):
+    for index, timestep in (t:=tqdm(list(enumerate(timesteps))[::-1])):
+      GlobalCounters.reset()
+      t.set_description("%3d %3d" % (index, timestep))
+      with Timing("step in ", enabled=args.timing, on_exit=lambda _: f", using {GlobalCounters.mem_used/1e9:.2f} GB"):
+        tid = Tensor([index])
+        latent = run(model, unconditional_context, context, latent, Tensor([timestep]), alphas[tid], alphas_prev[tid], Tensor([args.guidance]))
+        if args.timing: Device[Device.DEFAULT].synchronize()
+    del run
+
+  # upsample latent space to image with autoencoder
+  x = model.decode(latent)
+  print(x.shape)
+
+  # save image
+  im = Image.fromarray(x.numpy().astype(np.uint8, copy=False))
+  print(f"saving {args.out}")
+  im.save(args.out)
+  # Open image.
+  if not args.noshow: im.show()
+
+  # validation!
+  if args.prompt == default_prompt and args.steps == 5 and args.seed == 0 and args.guidance == 7.5:
+    ref_image = Tensor(np.array(Image.open(Path(__file__).parent / "stable_diffusion_seed0.png")))
+    distance = (((x - ref_image).cast(dtypes.float) / ref_image.max())**2).mean().item()
+    assert distance < 3e-4, colored(f"validation failed with {distance=}", "red")
+    print(colored(f"output validated with {distance=}", "green"))

Some files were not shown because too many files changed in this diff