-
Notifications
You must be signed in to change notification settings - Fork 3
/
mistral_lora.py
51 lines (43 loc) · 1.61 KB
/
mistral_lora.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import os
import shlex
import subprocess
import sys
import textwrap
import time
import urllib.request
import zipfile

import anyio
import dagger
import yaml
# Container image that main() passes to `docker run` to perform inference.
IMAGE = "quay.io/lukemarsden/axolotl:v0.0.1"
# Question that main() feeds to the inference CLI and prints next to the answer.
PROMPT = "If I put up a hammock hung between opposite sides of a round lake, go to sleep in the hammock and fall out, where will I land?"
def _inference_shell_command(prompt):
    """Return the bash command line that pipes *prompt* into axolotl inference."""
    # shlex.quote stops quotes, `$` or backticks in the prompt from being
    # interpreted by bash inside the container.
    return (
        f"echo {shlex.quote(prompt)} |"
        "python -u -m axolotl.cli.inference examples/mistral/qlora-instruct.yml"
    )


async def main():
    """Ask the model PROMPT by running IMAGE through docker from a Dagger pipeline.

    A local ``socat`` process exposes /var/run/docker.sock on
    tcp://172.17.0.1:12345 so that the ``docker`` CLI running inside the
    Dagger container can start IMAGE with GPU access. The model's stdout is
    printed as the answer. Pipeline errors are reported on stdout rather than
    raised, and the forwarder is always terminated before returning.
    """
    print("Spawning docker socket forwarder...")
    p = subprocess.Popen(
        [
            "socat",
            "TCP-LISTEN:12345,reuseaddr,fork,bind=172.17.0.1",
            "UNIX-CONNECT:/var/run/docker.sock",
        ]
    )
    try:
        # Give socat a moment to start listening; anyio.sleep keeps the event
        # loop free, unlike time.sleep.
        await anyio.sleep(1)
        print("Done!")

        config = dagger.Config(log_output=sys.stdout)
        async with dagger.Connection(config) as client:
            try:
                # NOTE(review): the exec args below rely on with_exec
                # prepending the entrypoint; newer Dagger releases may need
                # use_entrypoint=True for that - confirm against the pinned
                # SDK version.
                python = (
                    client
                    .container()
                    .from_("docker:latest")  # TODO: use '@sha256:...'
                    # To break the cache, add:
                    #   .with_env_variable("BREAK_CACHE", str(time.time()))
                    .with_entrypoint(["/usr/local/bin/docker"])
                    .with_exec(
                        [
                            "-H", "tcp://172.17.0.1:12345",
                            "run", "-i", "--rm", "--gpus", "all",
                            IMAGE,
                            "bash", "-c", _inference_shell_command(PROMPT),
                        ]
                    )
                )
                # Awaiting stdout executes the pipeline.
                out = await python.stdout()
                print(f"Question: {PROMPT}\n\nAnswer: {out}")
            except Exception as e:
                # Best-effort script: report the failure and fall through to
                # the cleanup below.
                print(f"error: {e}")
    finally:
        # Stop the forwarder even when connecting to Dagger or the pipeline fails.
        p.terminate()
# Run the pipeline only when executed as a script, so that importing this
# module does not launch it.
if __name__ == "__main__":
    anyio.run(main)