A Deep Investigation into MMAP Not Leaking Memory

September 28, 2022 by stas | Filed under Machine Learning, Software Engineering.

This write-up demonstrates that while mmap‘ed file IO may look like it’s leaking memory, it actually is not.

HuggingFace’s datasets project uses MMAP to make datasets available to multiple processes in an efficient way. This is very important since a machine learning training program will typically use a DataLoader with multiple workers, or alternatively the same dataset is simply accessed by multiple processes.

An issue was posted suggesting that a datasets-based program leaks memory on each iteration. This triggered extensive research into understanding why MMAP doesn’t leak memory, and brought a much deeper understanding of the different components used under the hood of datasets.

If you’d like to gain a deeper understanding of why and how, please read on.

Emulating a computer with just 1GB of memory

Since we don’t want to crash our computer while debugging memory issues, we are going to emulate a computer with just 1GB of memory and no swap. Unless such a computer has protection against programs using more memory than is physically available, most of the time it will start thrashing and eventually crash.

To accomplish that we are going to start a cgroups-controlled shell which will kill any program started from that shell and which consumes more than 1GB of memory (and give it no swap memory either):

$ systemd-run --user --scope -p MemoryHigh=1G -p MemoryMax=1G -p MemorySwapMax=0G --setenv="MEMLIMIT=1GB" bash

I’m setting the MEMLIMIT=1GB env variable so that at any moment I can check whether I’m in the right shell by printing it:

$ echo $MEMLIMIT

Let’s validate that this shell allows a program to allocate under 1GB of RSS RAM, but will kill it if it tries to allocate more than that:

# 7 * 128M chars
$ python -c "import sys, os, psutil; a='a'*7*2**27; print(f'{psutil.Process(os.getpid()).memory_info().rss >> 20}MB');"

# 8 * 128M chars
$ python -c "import sys, os, psutil; a='a'*8*2**27; print(f'{psutil.Process(os.getpid()).memory_info().rss >> 20}MB');"

So we can see that allocating just under 1GB works, but a program that asks for more than 1GB of resident memory gets killed.
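The arithmetic behind those two runs is worth a quick sanity check (note that `2**27` bytes is 128MiB, and systemd interprets `1G` as 1GiB):

```python
chunk = 2**27   # one block of 128M 'a' characters, in bytes
limit = 2**30   # systemd's MemoryMax=1G means 1GiB

print(7 * chunk / 2**20, "MiB")  # 896.0 MiB: fits under the limit
print(8 * chunk / 2**20, "MiB")  # 1024.0 MiB: hits the limit, process is killed
```

So the 7-chunk allocation leaves ~128MiB of headroom for the interpreter itself, while the 8-chunk one cannot fit.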

In the rest of this write-up, let’s use shell A, which is unlimited (or rather, limited only by the memory actually available on your computer), and shell B, where a program can only allocate 1GB of resident memory.

Sidenote: Linux memory management and reporting is super-complicated and one could probably easily write a whole book about it. Resident Set Size (RSS) is typically the easiest metric for measuring a program’s approximate actual memory usage. It doesn’t tell you the whole truth, but most of the time it’s good enough to detect memory leaks. Therefore this is the metric we are going to use in this write-up.
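If psutil isn’t available, the standard library can report a related number. Here is a minimal sketch using the `resource` module; note that `ru_maxrss` is the *peak* RSS (not the current one), reported in KiB on Linux but in bytes on macOS:

```python
import resource
import sys

def peak_rss_mb() -> float:
    """Peak resident set size of this process, in MiB."""
    rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    # Linux reports KiB, macOS reports bytes
    return rss / (2**20 if sys.platform == "darwin" else 2**10)

print(f"peak RSS so far: {peak_rss_mb():.2f}MB")
```

Since it is a high-water mark, it can only grow, so it is useful for spotting peaks but not for watching memory get released.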

Simple IO debug program

Now let’s write a simple debug program that will create a file with a few very large lines, and then read them sequentially using normal IO; but if we pass --mmap it’ll switch to the memory-mapped API via the mmap module.

Additionally, if the --accumulate flag is passed, the program will accumulate the lines it reads into a single string.

$ cat mmap-no-leak-debug.py
import gc
import mmap
import os
import psutil
import sys

PATH = "./tmp.txt"
# create a large data file with a few long lines
if not os.path.exists(PATH):
    with open(PATH, "w") as fh:
        s = 'a'* 2**27 + "\n" # 128MB
        # write a ~2GB file
        for i in range(16):
            fh.write(s)
proc = psutil.Process(os.getpid())
def mem_read():
    gc.collect() # force garbage collection before measuring
    return proc.memory_info().rss / 2**20

print(f"{'idx':>4} {'RSS':>10}   {'Δ RSS':>12}   {'Δ accumulated':>10}")

content = ''
mem_after = mem_before_acc = mem_after_acc = mem_before = mem_read()
print(f"{0:4d} {mem_after:10.2f}MB {mem_after - 0:10.2f}MB {0:10.2f}MB")

mmap_mode = "--mmap" in sys.argv

with open(PATH, "r") as fh:

    if mmap_mode:
        mm = mmap.mmap(fh.fileno(), 0, access=mmap.ACCESS_READ)

    idx = 0
    while True:
        idx += 1
        mem_before = mem_read()
        line = mm.readline() if mmap_mode else fh.readline()
        if not line:
            break
        if "--accumulate" in sys.argv:
            mem_before_acc = mem_read()
            content += str(line)
            mem_after_acc = mem_read()

        mem_after = mem_read()

        print(f"{idx:4d} {mem_after:10.2f}MB {mem_after - mem_before:10.2f}MB {mem_after_acc - mem_before_acc:10.2f}MB")

The four output columns are:

 idx        RSS          Δ RSS   Δ accumulated
  1. the line number (starting from 1)
  2. the total RSS reported at the end of each iteration
  3. the RSS delta of each iteration
  4. the accumulated buffer delta

And as you can see, we force Python’s garbage collection via gc.collect() before taking each RSS (Resident Set Size) measurement. This is a crucial step when debugging memory usage, and leaks in particular: Python’s garbage collection is not immediate, so if you delete some objects and want to verify that the memory is actually freed, you must collect first.
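A small demonstration of why forcing a collection matters: objects caught in reference cycles are not freed by reference counting alone, so they stay resident until a GC pass runs. This is a self-contained sketch (gc is disabled temporarily just to make the demo deterministic):

```python
import gc

gc.disable()  # no automatic collections during the demo

class Node:
    def __init__(self):
        self.ref = self  # reference cycle: refcounting alone can't free this

n = Node()
del n                     # unreachable now, but still resident
collected = gc.collect()  # force a full collection before measuring memory
gc.enable()
print(f"collected {collected} unreachable objects")
```

Without the explicit `gc.collect()`, the `Node` instance would linger and inflate any RSS measurement taken right after `del`.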

Normal IO diagnostics

First, let’s run normal IO without accumulating any strings; we simply discard them.

shell A $ python mmap-no-leak-debug.py
 idx        RSS          Δ RSS   Δ accumulated
   0      12.37MB      12.37MB       0.00MB
   1     269.66MB     257.29MB       0.00MB
   2     269.68MB       0.02MB       0.00MB
   3     269.68MB       0.00MB       0.00MB
   4     269.69MB       0.01MB       0.00MB
   5     269.69MB       0.00MB       0.00MB
   6     269.70MB       0.01MB       0.00MB
   7     269.70MB       0.00MB       0.00MB
   8     269.70MB       0.01MB       0.00MB
   9     269.70MB       0.00MB       0.00MB
  10     269.71MB       0.01MB       0.00MB
  11     269.71MB       0.00MB       0.00MB
  12     269.71MB       0.00MB       0.00MB
  13     269.71MB       0.00MB       0.00MB
  14     269.71MB       0.00MB       0.00MB
  15     269.71MB       0.00MB       0.00MB
  16     145.96MB    -123.75MB       0.00MB

Here we read a 128MB line in a loop and then discard it.

We can see that memory usage is low and steady, with some fluctuations when Python decides to release memory. Note that the program transiently allocates about 256MB for a 128MB line: most likely Python’s text-mode readline briefly holds both the raw bytes buffer and the decoded string, a peculiar but harmless Python behavior.
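A quick check confirms that the line object itself is only ~128MB (on CPython, `sys.getsizeof` reports the string’s own footprint), so the extra transient RSS in iteration 1 comes from how the string is built, not from the string:

```python
import sys

s = "a" * 2**27 + "\n"  # the same 128MB line the script writes
# ~128MB: the string object itself, so it alone cannot explain a ~256MB spike
print(f"{sys.getsizeof(s) / 2**20:.2f}MB")
```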

The bottom line is that the program doesn’t appear to be leaking any memory.

MMAP’ed IO diagnostics

Now let’s do the exact same operation but this time using mmap‘s IO:

shell A $ python mmap-no-leak-debug.py --mmap
idx        RSS          Δ RSS   Δ accumulated
   0      12.39MB      12.39MB       0.00MB
   1     268.25MB     255.87MB       0.00MB
   2     396.47MB     128.22MB       0.00MB
   3     524.47MB     128.00MB       0.00MB
   4     652.47MB     128.00MB       0.00MB
   5     780.47MB     128.00MB       0.00MB
   6     908.47MB     128.00MB       0.00MB
   7    1036.47MB     128.00MB       0.00MB
   8    1164.47MB     128.00MB       0.00MB
   9    1292.47MB     128.00MB       0.00MB
  10    1420.47MB     128.00MB       0.00MB
  11    1548.47MB     128.00MB       0.00MB
  12    1676.47MB     128.00MB       0.00MB
  13    1804.47MB     128.00MB       0.00MB
  14    1932.47MB     128.00MB       0.00MB
  15    2060.47MB     128.00MB       0.00MB
  16    2188.47MB     128.00MB       0.00MB

Whoah! It looks like there is a major leak here. On each iteration the program keeps on growing by 128MB despite us discarding the read data. What’s going on?

The theoretical explanation is simple: MMAP was designed to make IO faster and shareable by multiple processes, so if there is a lot of available RAM, the MMAP API will use as much of it as it can, and in order to speed things up it won’t normally release it back to the OS. For example, if you have two programs reading the same sections of the same MMAP’ed file, only the first program will incur the delay of copying the data from disk to RAM; the other program will read it directly from RAM. Since MMAP doesn’t know which sections will be accessed next, it simply keeps everything it has read in memory, as long as there is enough of it.

You might say this is very bad, a terrible design even. But wait: it only keeps the data in memory if nobody else wants that memory, and it releases the unused pages back to the operating system as soon as such demand arises.
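The behavior is easy to poke at with the stdlib alone. This sketch (with sizes shrunk way down, writing to a temporary path) reads a file line by line through mmap just like the script above does, and also shows the `madvise` hint (Linux, Python 3.8+) with which a program can explicitly hand its mapped pages back:

```python
import mmap
import os
import tempfile

# write a small file, then read it back line by line through mmap
path = os.path.join(tempfile.mkdtemp(), "demo.txt")
with open(path, "w") as fh:
    for _ in range(4):
        fh.write("x" * 1024 + "\n")

lines = []
with open(path, "r") as fh:
    mm = mmap.mmap(fh.fileno(), 0, access=mmap.ACCESS_READ)
    while True:
        line = mm.readline()  # returns bytes; b"" at EOF
        if not line:
            break
        lines.append(line)
    if hasattr(mmap, "MADV_DONTNEED"):    # Linux, Python 3.8+
        mm.madvise(mmap.MADV_DONTNEED)    # hint: the kernel may drop these pages now
    mm.close()

print(len(lines), len(lines[0]))  # 4 lines of 1025 bytes (1024 x's + newline)
```

Normally you never need the `madvise` call; the kernel evicts those pages on its own under memory pressure, which is exactly what the rest of this article demonstrates.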

Proof that there is no leak

To show that the memory does get released as soon as it’s needed let’s re-run this same program in shell B, where only 1GB of memory is allowed to be allocated.

shell B $ systemd-run --user --scope -p MemoryHigh=1G -p MemoryMax=1G -p MemorySwapMax=0G --setenv="MEMLIMIT=1GB" bash
shell B $ python mmap-no-leak-debug.py --mmap
 idx        RSS          Δ RSS   Δ accumulated
   0      12.48MB      12.48MB       0.00MB
   1     268.51MB     256.03MB       0.00MB
   2     396.73MB     128.22MB       0.00MB
   3     524.73MB     128.00MB       0.00MB
   4     652.73MB     128.00MB       0.00MB
   5     780.73MB     128.00MB       0.00MB
   6     908.73MB     128.00MB       0.00MB
   7    1036.73MB     128.00MB       0.00MB
   8    1164.73MB     128.00MB       0.00MB
   9    1292.73MB     128.00MB       0.00MB
  10    1420.73MB     128.00MB       0.00MB
  11    1548.73MB     128.00MB       0.00MB
  12    1676.73MB     128.00MB       0.00MB
  13    1804.73MB     128.00MB       0.00MB
  14    1932.73MB     128.00MB       0.00MB
  15    2060.73MB     128.00MB       0.00MB
  16    2188.69MB     127.95MB       0.00MB

Surprise: the program managed to allocate >2GB of memory, even though we verified earlier that it should get killed as soon as it reached 1GB of RSS, since we limited the shell to allow less than 1GB of memory allocation!

We will understand better shortly what’s going on, but it’s clear that the cgroups memory controller, while it accounts the MMAP’ed memory toward the program’s RSS counter, is aware that the program itself isn’t really using most of that memory!

Interim observation: we can’t rely on RSS memory stats to diagnose memory leaks when MMAP is used.
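On Linux there is a finer-grained breakdown available than plain RSS. This is a hedged sketch reading `/proc/self/smaps_rollup` (kernel 4.14+): the `Anonymous` field counts only anonymous (heap/stack) pages, so clean file-backed pages such as mmap’ed data drop out, which often makes it a better leak signal than RSS:

```python
def smaps_rollup_kib(field: str) -> int:
    """Read one field (e.g. 'Rss', 'Anonymous') from /proc/self/smaps_rollup, in KiB.

    Linux-only (kernel 4.14+)."""
    with open("/proc/self/smaps_rollup") as fh:
        for line in fh:
            if line.startswith(field + ":"):
                return int(line.split()[1])
    raise KeyError(field)

print(f"Rss: {smaps_rollup_kib('Rss')}KiB  "
      f"Anonymous: {smaps_rollup_kib('Anonymous')}KiB")
```

In an mmap-heavy program you would expect `Rss` to balloon while `Anonymous` stays roughly flat unless there is a real leak.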

Let’s create memory pressure

This is where our --accumulate flag comes in. It’s going to help us see that RSS “misreports” the actual memory used by the program.

First we run it with normal IO:

shell A $ python mmap-no-leak-debug.py --accumulate
 idx        RSS          Δ RSS   Δ accumulated
   0      12.30MB      12.30MB       0.00MB
   1     269.60MB     257.29MB       0.00MB
   2     525.49MB     255.89MB     127.93MB
   3     653.49MB     128.00MB     127.87MB
   4     781.50MB     128.01MB     127.87MB
   5     909.50MB     128.00MB     127.87MB
   6    1037.51MB     128.01MB     127.87MB
   7    1165.51MB     128.00MB     127.87MB
   8    1293.52MB     128.01MB     127.87MB
   9    1421.52MB     128.00MB     127.87MB
  10    1549.53MB     128.01MB     127.87MB
  11    1677.53MB     128.00MB     127.87MB
  12    1805.53MB     128.00MB     127.87MB
  13    1933.53MB     128.00MB     127.87MB
  14    2061.53MB     128.00MB     127.87MB
  15    2189.53MB     128.00MB     127.87MB
  16    2193.78MB       4.25MB     127.87MB

Here RSS correctly reports 128*16 ~= 2048MB, plus a bit more for the other parts of the program, but the ballpark matches.

Now let’s activate MMAP and re-run:

shell A $ python mmap-no-leak-debug.py --mmap --accumulate
 idx        RSS          Δ RSS   Δ accumulated
   0      12.37MB      12.37MB       0.00MB
   1     396.39MB     384.02MB     128.13MB
   2     652.48MB     256.09MB     128.00MB
   3     908.48MB     256.00MB     128.00MB
   4    1164.48MB     256.00MB     128.00MB
   5    1420.48MB     256.00MB     128.00MB
   6    1676.48MB     256.00MB     128.00MB
   7    1932.48MB     256.00MB     128.00MB
   8    2188.48MB     256.00MB     128.00MB
   9    2444.48MB     256.00MB     128.00MB
  10    2700.48MB     256.00MB     128.00MB
  11    2956.48MB     256.00MB     128.00MB
  12    3212.48MB     256.00MB     128.00MB
  13    3468.48MB     256.00MB     128.00MB
  14    3724.48MB     256.00MB     128.00MB
  15    3980.48MB     256.00MB     128.00MB
  16    4236.46MB     255.98MB     128.00MB

Here we can see that RSS reports about twice the memory the program actually uses: each iteration adds the 128MB appended to the accumulator plus another 128MB of mmap’ed pages kept around by the OS.

And now let’s create pressure using our 1GB-limited shell B and use normal IO with accumulation:

shell B $ systemd-run --user --scope -p MemoryHigh=1G -p MemoryMax=1G -p MemorySwapMax=0G --setenv="MEMLIMIT=1GB" bash
shell B $ python mmap-no-leak-debug.py --accumulate
 idx        RSS          Δ RSS   Δ accumulated
   0      12.38MB      12.38MB       0.00MB
   1     269.41MB     257.04MB       0.00MB
   2     525.55MB     256.14MB     127.93MB
   3     653.55MB     128.00MB     127.87MB
   4     781.56MB     128.01MB     127.87MB
   5     909.56MB     128.00MB     127.87MB

As you can easily see, the program gets killed once it reaches 1GB of RSS. It completed 5 iterations; on iteration 6, the accumulated buffer would grow to 6*128=768MB, plus the 128MB line just read, plus the memory used by the rest of the program, which crosses 1GB, so it gets killed before finishing iteration 6.

It might also be useful to compare with the same run in shell A: the RSS of the shell B run is somewhat different, and the reported RSS doesn’t grow as fast.

Now let’s run the MMAPed version:

shell B $ systemd-run --user --scope -p MemoryHigh=1G -p MemoryMax=1G -p MemorySwapMax=0G --setenv="MEMLIMIT=1GB" bash
shell B $ python mmap-no-leak-debug.py --mmap --accumulate
 idx        RSS          Δ RSS   Δ accumulated
   0      12.51MB      12.51MB       0.00MB
   1     396.52MB     384.00MB     128.13MB
   2     652.60MB     256.08MB     128.00MB
   3     908.60MB     256.00MB     128.00MB
   4    1164.60MB     256.00MB     128.00MB
   5    1420.60MB     256.00MB     128.00MB

You can see that while the RSS numbers are bigger than those of the normal IO run, the program gets killed in exactly the same iteration as when it was run without MMAP, which tells us that the actual memory usage with normal IO and mmap’ed IO is very similar, and very likely exactly the same.

What about PyArrow?

Originally this whole research started from this Issue in the datasets repo. It looked like a dataset loaded via pyarrow leaked memory on every iteration.

Quentin Lhoest reduced it to a simple pyarrow program:

$ cat mmap-no-leak-debug-pyarrow.py
import psutil
import os
import gc
import pyarrow as pa

ARROW_PATH = "tmp.arrow"

if not os.path.exists(ARROW_PATH):
    arr = pa.array([b"a" * (200 * 1024)] * 1000)  # ~200MB
    table = pa.table({"a": arr})

    with open(ARROW_PATH, "wb") as f:
        writer = pa.RecordBatchStreamWriter(f, schema=table.schema)
        writer.write_table(table)
        writer.close()

def memory_mapped_arrow_table_from_file(filename: str) -> pa.Table:
    memory_mapped_stream = pa.memory_map(filename)
    opened_stream = pa.ipc.open_stream(memory_mapped_stream)
    pa_table = opened_stream.read_all()
    return pa_table

table = memory_mapped_arrow_table_from_file(ARROW_PATH)
arr = table[0]

print(f"{'idx':>8} {'RSS':>10} {'Δ RSS':>15}")

mem_before = psutil.Process(os.getpid()).memory_info().rss / (1024 * 1024)
for idx, x in enumerate(arr):
    if idx % 100 == 0:
        mem_after = psutil.Process(os.getpid()).memory_info().rss / (1024 * 1024)
        print(f"{idx:4d}  {mem_after:12.4f}MB {mem_after - mem_before:12.4f}MB")

which, when run, produces the familiar leak-like pattern:

$ python mmap-no-leak-debug-pyarrow.py
     idx        RSS           Δ RSS
   0       51.3164MB       2.5430MB
 100       69.9805MB      21.2070MB
 200       90.6055MB      41.8320MB
 300      107.1055MB      58.3320MB
 400      127.7305MB      78.9570MB
 500      148.3555MB      99.5820MB
 600      164.8555MB     116.0820MB
 700      185.4805MB     136.7070MB
 800      206.1055MB     157.3320MB
 900      226.7305MB     177.9570MB

But if we run it from a shell that is only allowed 100MB of allocated memory:

$ systemd-run --user --scope -p MemoryHigh=0.1G -p MemoryMax=0.1G -p MemorySwapMax=0G --setenv="MEMLIMIT=0.1GB" bash
$ python mmap-no-leak-debug-pyarrow.py
     idx        RSS           Δ RSS
   0       51.2852MB       2.4609MB
 100       70.4102MB      21.5859MB
 200       86.9102MB      38.0859MB
 300      107.5352MB      58.7109MB
 400      128.1602MB      79.3359MB
 500      148.7852MB      99.9609MB
 600      165.2852MB     116.4609MB
 700      185.9102MB     137.0859MB
 800      206.5352MB     157.7109MB
 900      227.1602MB     178.3359MB

So it reports it allocated ~200MB of RSS, yet it runs just fine without getting killed.

There is no leak here.

What about HuggingFace datasets?

In another Issue, a very similar report of a datasets iterator leaking memory was submitted.

So let’s use a similar datasets reproduction example here, but with a larger dataset.

$ cat mmap-no-leak-debug-datasets.py
from datasets import load_dataset
import gc
import os
import psutil
import sys

keep_in_memory = "--in-mem" in sys.argv

proc = psutil.Process(os.getpid())
def mem_read():
    return proc.memory_info().rss / 2**20

dataset = load_dataset("wmt19", 'cs-en', keep_in_memory=keep_in_memory, streaming=False)['train']
print(f"Dataset len={len(dataset)}")

print(f"{'idx':>8} {'RSS':>10} {'Δ RSS':>15}")
step = 1_000_000
mem_start = 0
for idx, i in enumerate(range(0, len(dataset), step)):
    if idx == 4: # skip the first few iterations while things get set up
        mem_start = mem_read()
    mem_before = mem_read()
    x = dataset[i:i+step]
    mem_after = mem_read()
    print(f"{idx:8d} {mem_after:12.4f}MB {mem_after - mem_before:12.4f}MB ")
mem_end = mem_read()

print(f"Total diff: {mem_end - mem_start:12.4f}MB ")

Let’s run it in a normal shell first:

$ python mmap-no-leak-debug-datasets.py
Dataset len=7270695
     idx        RSS           Δ RSS
       0     775.7773MB     609.9805MB
       1     849.6016MB      73.8242MB
       2     876.1445MB      26.5430MB
       3     941.3477MB      65.2031MB
       4     984.9570MB      43.6094MB
       5    1053.6445MB      68.6875MB
       6    1164.2852MB     110.6406MB
       7    1252.5312MB      88.2461MB
       8    1368.6523MB     116.1211MB
       9    1445.7266MB      77.0742MB
      10    1564.5195MB     118.7930MB
      11    1678.7500MB     114.2305MB
      12    1729.9844MB      51.2344MB
      13    1866.1953MB     136.2109MB
Total diff:    1700.3984MB

You can see the middle column, the total RSS, keeps growing in MBs. The last column shows by how much it grew during a single iteration of the script (1M items).

And now let’s run in a 1GB limited shell:

$ systemd-run --user --scope -p MemoryHigh=1G -p MemoryMax=1G -p MemorySwapMax=0G --setenv="MEMLIMIT=1GB" bash
$ python mmap-no-leak-debug-datasets.py
Dataset len=7270695
     idx        RSS           Δ RSS
       0     775.8516MB     610.1797MB
       1     849.5820MB      73.7305MB
       2     876.1328MB      26.5508MB
       3     941.3281MB      65.1953MB
       4     984.9375MB      43.6094MB
       5    1053.6328MB      68.6953MB
       6    1164.0273MB     110.3945MB
       7    1252.5273MB      88.5000MB
       8    1368.3906MB     115.8633MB
       9    1445.7188MB      77.3281MB
      10    1564.2656MB     118.5469MB
      11    1678.7383MB     114.4727MB
      12    1729.7227MB      50.9844MB
      13    1866.1875MB     136.4648MB
Total diff:    1700.5156MB

No problem at all.

So we now know there is no leak there either; it’s just that the OS includes in RSS memory that will be released as soon as it’s needed.

How to debug real memory leaks while using MMAP

So how does one debug an actual memory leak that might be elsewhere in the code while using MMAP?

Well, you have to disable MMAP for the duration of your debug session and re-enable it when you want the high performance back.

As you have seen at the beginning of this article switching from mmap to normal IO is very simple to do.

In the case of datasets you’d turn the MMAP functionality off with keep_in_memory=True, as in:

dataset = load_dataset("wmt19", 'cs-en', keep_in_memory=True, streaming=False)['train']

This loads the dataset in RAM, and now you should be able to debug your potential leak.
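One convenient pattern (a sketch; the `DS_DEBUG_IN_MEMORY` variable name is made up here) is to drive the flag from the environment, so you don’t have to edit code between debug and normal runs:

```python
import os

# hypothetical env-var toggle: run `DS_DEBUG_IN_MEMORY=1 python my_prog.py`
# to disable mmap for a debugging session; leave it unset for normal runs
keep_in_memory = os.environ.get("DS_DEBUG_IN_MEMORY", "0") == "1"
print(f"keep_in_memory={keep_in_memory}")
```

Then pass `keep_in_memory=keep_in_memory` to load_dataset, and the default behavior stays the fast mmap’ed path.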

Let’s test after modifying our last program:

- dataset = load_dataset("wmt19", 'cs-en', keep_in_memory=False, streaming=False)['train']
+ dataset = load_dataset("wmt19", 'cs-en', keep_in_memory=True, streaming=False)['train']

Now in the normal unlimited shell we run:

$ python mmap-no-leak-debug-datasets.py --in-mem
Dataset len=7270695
     idx        RSS           Δ RSS
       0    1849.5391MB     469.5781MB
       1    1833.0391MB     -16.5000MB
       2    1803.4609MB     -29.5781MB
       3    1811.5312MB       8.0703MB
       4    1803.9531MB      -7.5781MB
       5    1811.7734MB       7.8203MB
       6    1836.0391MB      24.2656MB
       7    1839.5938MB       3.5547MB
       8    1855.9688MB      16.3750MB
       9    1850.5430MB      -5.4258MB
      10    1865.3398MB      14.7969MB
      11    1876.2461MB      10.9062MB
      12    1853.0469MB     -23.1992MB
      13    1881.4453MB      28.3984MB
Total diff:     501.4844MB

The RSS is now more stable, but it still fluctuates because the records differ from one another; also, a real dataset may simply be too huge to load into memory at all.

Using synthetic MMAP-disabled dataset to debug memory leaks

Therefore the easiest approach is to create a synthetic dataset of the desired length with all records identical. That way the data is no longer a factor in the memory usage pattern, since it’s always the same.

$ cat ds-synthetic-no-mmap.py
from datasets import load_from_disk, Dataset
import gc
import sys
import os
import psutil

proc = psutil.Process(os.getpid())
def mem_read():
    return proc.memory_info().rss / 2**20

DS_PATH = "synthetic-ds"
if not os.path.exists(DS_PATH):
    records = 1_000_000
    print("Creating a synthetic dataset")
    row = dict(foo=[dict(a='a'*500, b='b'*1000)])
    ds = Dataset.from_dict({k: [v] * records for k, v in row.items()})
    ds.save_to_disk(DS_PATH)
    print("Done. Please restart the program")
    sys.exit(0)

dataset = load_from_disk(DS_PATH, keep_in_memory=True)
print(f"Dataset len={len(dataset)}")

print(f"{'idx':>8} {'RSS':>10} {'Δ RSS':>15}")
mem_start = 0
step = 50_000
warmup_iterations = 4
for idx, i in enumerate(range(0, len(dataset), step)):
    if idx == warmup_iterations: # skip the first few iterations while things get set up
        mem_start = mem_read()
    mem_before = mem_read()
    _ = dataset[i:i+step]
    mem_after = mem_read()
    print(f"{i:8d} {mem_after:12.4f}MB {mem_after - mem_before:12.4f}MB")
mem_end = mem_read()

print(f"Total diff: {mem_end - mem_start:12.4f}MB (after {warmup_iterations} warmup iterations)")

We run this program once to create the dataset, and then the second time to profile its memory usage:

$ python ds-synthetic-no-mmap.py
Creating a synthetic dataset
Done. Please restart the program
$ python ds-synthetic-no-mmap.py
Dataset len=1000000
     idx        RSS           Δ RSS
       0    1649.6055MB      95.1992MB
   50000    1728.4961MB      78.8906MB
  100000    1728.7109MB       0.2148MB
  150000    1729.2539MB       0.5430MB
  200000    1729.0039MB      -0.2500MB
  250000    1729.5039MB       0.5000MB
  300000    1729.2539MB      -0.2500MB
  350000    1729.7539MB       0.5000MB
  400000    1729.5039MB      -0.2500MB
  450000    1730.0039MB       0.5000MB
  500000    1729.7539MB      -0.2500MB
  550000    1730.2539MB       0.5000MB
  600000    1730.0039MB      -0.2500MB
  650000    1730.5039MB       0.5000MB
  700000    1730.2539MB      -0.2500MB
  750000    1730.7539MB       0.5000MB
  800000    1730.5039MB      -0.2500MB
  850000    1731.0039MB       0.5000MB
  900000    1730.7539MB      -0.2500MB
  950000    1731.2539MB       0.5000MB
Total diff:       2.0000MB (after 4 warmup iterations)

This is much better. There are still tiny fluctuations due to Python, and as you can see in the code, I skipped the first few iterations while things are being set up.

But otherwise, now you can easily debug the rest of your code for memory leaks, since datasets is in non-MMAP mode and the record sizes don’t fluctuate.

Of course, do not forget to flip load_from_disk(..., keep_in_memory=True) back to False when the debugging is over, so that you regain the performance speed-up provided by MMAP.

I wrote these notes mainly for myself to ensure I have a good understanding of this complex use-case. And I hope you have gained some understanding from it as well.


