Tip
์ด ํํ ๋ฆฌ์ผ์ ์ต๋ํ ํ์ฉํ๋ ค๋ฉด ์ด Colab ๋ฒ์ ์ ์ฌ์ฉํ๋ ๊ฒ์ด ์ข์ต๋๋ค. ์ด๋ฅผ ํตํด ์๋์ ์ ์๋ ์ ๋ณด๋ฅผ ์คํํ ์ ์์ต๋๋ค.
์๋ ๋์์์ด๋ ์ ํ๋ธ ์์ ๋ฐ๋ผํด๋ณด์ธ์.
์ถ์ฒ ์์คํ
์ ๋ง๋ค ๋, ์ ํ์ด๋ ํ์ด์ง์ ๊ฐ์ ๊ฐ์ฒด๋ฅผ ์๋ฒ ๋ฉ์ผ๋ก ํํํ๊ณ ์ถ์ ๊ฒฝ์ฐ๊ฐ ๋ง์ต๋๋ค.
Meta AI์ ๋ฅ๋ฌ๋ ์ถ์ฒ ๋ชจ๋ธ ๋๋ DLRM์ ์๋ก ๋ค ์ ์์ต๋๋ค.
๊ฐ์ฒด์ ์๊ฐ ์ฆ๊ฐํจ์ ๋ฐ๋ผ, ์๋ฒ ๋ฉ ํ
์ด๋ธ์ ํฌ๊ธฐ๊ฐ ๋จ์ผ GPU์ ๋ฉ๋ชจ๋ฆฌ๋ฅผ ์ด๊ณผํ ์ ์์ต๋๋ค.
์ผ๋ฐ์ ์ธ ๋ฐฉ๋ฒ์ ๋ชจ๋ธ ๋ณ๋ ฌํ์ ์ผ์ข
์ผ๋ก, ์๋ฒ ๋ฉ ํ
์ด๋ธ์ ์ฌ๋ฌ ๋๋ฐ์ด์ค๋ก ์ค๋ฉ(shard)ํ๋ ๊ฒ์
๋๋ค.
์ด๋ฅผ ์ํด, TorchRec์ DistributedModelParallel
๋๋ DMP๋ก ๋ถ๋ฆฌ๋ ์ฃผ์ํ API๋ฅผ ์๊ฐํฉ๋๋ค.
PyTorch์ DistributedDataParallel์ ๊ฐ์ด, DMP๋ ๋ถ์ฐ ํ์ต์ ๊ฐ๋ฅํ๊ฒํ๊ธฐ ์ํด ๋ชจ๋ธ์ ํฌ์ฅํฉ๋๋ค.
์๊ตฌ ์ฌํญ: python >= 3.7
TorchRec์ ์ฌ์ฉํ ๋๋ CUDA๋ฅผ ์ ๊ทน ์ถ์ฒํฉ๋๋ค. (CUDA๋ฅผ ์ฌ์ฉํ๋ ๊ฒฝ์ฐ: cuda >= 11.0)
# install pytorch with cudatoolkit 11.3
conda install pytorch cudatoolkit=11.3 -c pytorch-nightly -y
# install TorchTec
pip3 install torchrec-nightly
์ด ํํ ๋ฆฌ์ผ์์๋ TorchRec์ nn.module
EmbeddingBagCollection
, DistributedModelParallel
API,
๋ฐ์ดํฐ ๊ตฌ์กฐ KeyedJaggedTensor
3๊ฐ์ง ๋ด์ฉ์ ๋ค๋ฃน๋๋ค.
torch.distributed๋ฅผ ์ฌ์ฉํ์ฌ ํ๊ฒฝ์ ์ค์ ํฉ๋๋ค. ๋ถ์ฐ์ ๋ํ ์์ธํ ๋ด์ฉ์ ์ด ํํ ๋ฆฌ์ผ ์ ์ฐธ๊ณ ํ์ธ์.
์ฌ๊ธฐ์๋ 1๊ฐ์ colab GPU์ ๋์ํ๋ 1๊ฐ์ ๋ญํฌ(colab ํ๋ก์ธ์ค)๋ฅผ ์ฌ์ฉํฉ๋๋ค.
import os
import torch
import torchrec
import torch.distributed as dist
os.environ["RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29500"
# ์ฐธ๊ณ - ํํ ๋ฆฌ์ผ์ ์คํํ๋ ค๋ฉด V100 ๋๋ A100์ด ํ์ํฉ๋๋ค!
# colab free K80๊ณผ ๊ฐ์ ์ค๋๋ GPU๋ฅผ ์ฌ์ฉํ๋ค๋ฉด,
# ์ ์ ํ CUDA ์ํคํ
์ฒ๋ก fbgemm๋ฅผ ์ปดํ์ผํ๊ฑฐ๋,
# CPU์์ "gloo"๋ก ์คํํด์ผ ํฉ๋๋ค.
dist.init_process_group(backend="nccl")
PyTorch๋ torch.nn.Embedding
์ torch.nn.EmbeddingBag
๋ฅผ ํตํด ์๋ฒ ๋ฉ์ ๋ํ๋
๋๋ค.
EmbeddingBag์ ์๋ฒ ๋ฉ์ ํ(pool) ๋ฒ์ ์
๋๋ค.
TorchRec์ ์๋ฒ ๋ฉ ์ปฌ๋ ์
์ ์์ฑํ์ฌ ์ด ๋ชจ๋๋ค์ ํ์ฅํฉ๋๋ค.
EmbeddingBag ๊ทธ๋ฃน์ ๋ํ๋ด๊ณ ์ EmbeddingBagCollection
์ ์ฌ์ฉํฉ๋๋ค.
์ฌ๊ธฐ์๋, 2๊ฐ์ EmbeddingBag์ ๊ฐ์ง๋ EmbeddingBagCollection (EBC)์ ์์ฑํฉ๋๋ค.
๊ฐ ํ
์ด๋ธ product_table
๊ณผ user_table
๋ 4096 ํฌ๊ธฐ์ 64 ์ฐจ์ ์๋ฒ ๋ฉ์ผ๋ก ํํ๋ฉ๋๋ค.
โmetaโ ๋๋ฐ์ด์ค์์ EBC๋ฅผ ์ด๊ธฐ์ ํ ๋นํ๋ ๋ฐฉ๋ฒ์ ์ฃผ์ํ์ธ์. EBC์๊ฒ ์์ง ๋ฉ๋ชจ๋ฆฌ๊ฐ ํ ๋น๋์ง ์์์ต๋๋ค.
ebc = torchrec.EmbeddingBagCollection(
device="meta",
tables=[
torchrec.EmbeddingBagConfig(
name="product_table",
embedding_dim=64,
num_embeddings=4096,
feature_names=["product"],
pooling=torchrec.PoolingType.SUM,
),
torchrec.EmbeddingBagConfig(
name="user_table",
embedding_dim=64,
num_embeddings=4096,
feature_names=["user"],
pooling=torchrec.PoolingType.SUM,
)
]
)
์ด์ ๋ชจ๋ธ์ DistributedModelParallel
(DMP)๋ก ๊ฐ์ ์ค๋น๊ฐ ๋์์ต๋๋ค.
DMP์ ์ธ์คํด์คํ๋ ๋ค์๊ณผ ๊ฐ์ต๋๋ค.
- ๋ชจ๋ธ์ ์ค๋ฉํ๋ ๋ฐฉ๋ฒ์ ๊ฒฐ์ ํฉ๋๋ค. DMP๋ ์ด์ฉ ๊ฐ๋ฅํ โshardersโ๋ฅผ ์์งํ๊ณ ์๋ฒ ๋ฉ ํ ์ด๋ธ์ ์ค๋ฉํ๋ ์ต์ ์ ๋ฐฉ๋ฒ (์ฆ, the EmbeddingBagCollection)์ โplanโ์ ์์ฑํฉ๋๋ค.
- ๋ชจ๋ธ์ ์ค๋ฉํฉ๋๋ค. ์ด ๊ณผ์ ์ ๊ฐ ์๋ฒ ๋ฉ ํ ์ด๋ธ์ ์ ์ ํ ์ฅ์น๋ก ๋ฉ๋ชจ๋ฆฌ๋ฅผ ํ ๋นํ๋ ๊ฒ์ ํฌํจํฉ๋๋ค.
์ด ์์ ์์๋ 2๊ฐ์ EmbeddingTables๊ณผ ํ๋์ GPU๊ฐ ์๊ธฐ ๋๋ฌธ์, TorchRec์ ๋ชจ๋ ๋จ์ผ GPU์ ๋ฐฐ์นํฉ๋๋ค.
model = torchrec.distributed.DistributedModelParallel(ebc, device=torch.device("cuda"))
print(model)
print(model.plan)
input
๊ณผ offsets
์ด ์๋ nn.Embedding
๊ณผ nn.EmbeddingBag
๋ฅผ ์ง์ํฉ๋๋ค.
์
๋ ฅ์ lookup ๊ฐ์ ํฌํจํ๋ 1-D ํ
์์
๋๋ค.
์คํ์
์ ์ํ์ค๊ฐ ๊ฐ ์์ ์์ ๊ฐ์ ธ์ค๋ ๊ฐ์ ์์ ํฉ์ธ 1-D ํ
์์
๋๋ค.
์์ EmbeddingBag์ ๋ค์ ๋ง๋ค์ด๋ณด๋ ์๋ ๋ค์๊ณผ ๊ฐ์ต๋๋ค.
|------------| | product ID | |------------| | [101, 202] | | [] | | [303] | |------------|
product_eb = torch.nn.EmbeddingBag(4096, 64)
product_eb(input=torch.tensor([101, 202, 303]), offsets=torch.tensor([0, 2, 2]))
์์ ๋ฐ ๊ธฐ๋ฅ๋ณ๋ก ๊ฐ์ฒด ID๊ฐ ์์์ ์์ธ ๋ค์ํ ์์ ๋ฅผ ํจ์จ์ ์ผ๋ก ๋ํ๋ด์ผ ํฉ๋๋ค.
๋ค์ํ ํํ์ด ๊ฐ๋ฅํ๋๋ก, TorchRec ๋ฐ์ดํฐ๊ตฌ์กฐ KeyedJaggedTensor
(KJT)๋ฅผ ์ฌ์ฉํฉ๋๋ค.
โproductโ ์ โuserโ, 2๊ฐ์ EmbeddingBag์ ์ปฌ๋ ์ ์ ์ฐธ์กฐํ๋ ๋ฐฉ๋ฒ์ ์ดํด๋ด ๋๋ค. ๋ฏธ๋๋ฐฐ์น๊ฐ 3๋ช ์ ์ฌ์ฉ์์ 3๊ฐ์ ์์ ๋ก ๊ตฌ์ฑ๋์ด ์๋ค๊ณ ๊ฐ์ ํฉ๋๋ค. ์ฒซ ๋ฒ์งธ๋ 2๊ฐ์ product ID๋ฅผ ๊ฐ์ง๊ณ , ๋ ๋ฒ์งธ๋ ์๋ฌด๊ฒ๋ ๊ฐ์ง์ง ์๊ณ , ์ธ ๋ฒ์งธ๋ ํ๋์ product ID๋ฅผ ๊ฐ์ง๋๋ค.
|------------|------------| | product ID | user ID | |------------|------------| | [101, 202] | [404] | | [] | [505] | | [303] | [606] | |------------|------------|
์ง์๋ ๋ค์๊ณผ ๊ฐ์ต๋๋ค.
mb = torchrec.KeyedJaggedTensor(
keys = ["product", "user"],
values = torch.tensor([101, 202, 303, 404, 505, 606]).cuda(),
lengths = torch.tensor([2, 0, 1, 1, 1, 1], dtype=torch.int64).cuda(),
)
print(mb.to(torch.device("cpu")))
KJT ๋ฐฐ์น ํฌ๊ธฐ๋ batch_size = len(lengths)//len(keys)
์ธ ๊ฒ์ ๋์ฌ๊ฒจ๋ด ์ฃผ์ธ์.
์ ์์ ์์ batch_size๋ 3์
๋๋ค.
๋ง์ง๋ง์ผ๋ก ์ ํ๊ณผ ์ฌ์ฉ์์ ๋ฏธ๋๋ฐฐ์น๋ฅผ ์ฌ์ฉํ์ฌ ๋ชจ๋ธ์ ์ง์ํฉ๋๋ค.
๊ฒฐ๊ณผ ์กฐํ๋ KeyedTensor๋ฅผ ํฌํจํฉ๋๋ค. ๊ฐ ํค(key) ๋๋ ํน์ง(feature)์ ํฌ๊ธฐ๊ฐ 3x64 (batch_size x embedding_dim)์ธ 2D ํ ์๋ฅผ ํฌํจํฉ๋๋ค.
pooled_embeddings = model(mb)
print(pooled_embeddings)
์์ธํ ๋ด์ฉ์ dlrm ์์ ๋ฅผ ์ฐธ๊ณ ํ์ธ์. ์ด ์์ ๋ Meta์ DLRM ์ ์ฌ์ฉํ์ฌ 1ํ ๋ผ๋ฐ์ดํธ ๋ฐ์ดํฐ์ ์ ๋ํ ๋ฉํฐ ๋ ธ๋ ํ์ต์ ํฌํจํฉ๋๋ค.