
Introduction to TorchRec

Tip

To get the most out of this tutorial, we suggest using this Colab version. It will allow you to experiment with the information presented below.

Follow along with the video below or on YouTube.

When building recommendation systems, we frequently want to represent entities like products or pages with embeddings. For example, see Meta AI's deep learning recommendation model, or DLRM. As the number of entities grows, the size of the embedding tables can exceed a single GPU's memory. A common practice is to shard the embedding table across devices, a type of model parallelism. To that end, TorchRec introduces its primary API, DistributedModelParallel (DMP). Like PyTorch's DistributedDataParallel, DMP wraps a model to enable distributed training.

Installation

Requirements: python >= 3.7

We highly recommend CUDA when using TorchRec. (If using CUDA: cuda >= 11.0)

# install pytorch with cudatoolkit 11.3
conda install pytorch cudatoolkit=11.3 -c pytorch-nightly -y
# install TorchRec
pip3 install torchrec-nightly

Overview

This tutorial covers three pieces of TorchRec: the nn.module EmbeddingBagCollection, the DistributedModelParallel API, and the data structure KeyedJaggedTensor.

Distributed Setup

We set up our environment with torch.distributed. For more on distributed training, see this tutorial.

Here, we use one rank (the colab process) corresponding to our single colab GPU.

import os
import torch
import torchrec
import torch.distributed as dist

os.environ["RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29500"

# Note - you will need a V100 or A100 to run this tutorial!
# If using an older GPU (such as the colab free K80),
# you will need to compile fbgemm with the appropriate CUDA architecture,
# or run with "gloo" on CPU.
dist.init_process_group(backend="nccl")
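The note in the code above suggests falling back to "gloo" on CPU-only machines. A minimal sketch of that backend selection for the same single-rank setup (an illustrative variant, not part of the original tutorial):

```python
import os
import torch
import torch.distributed as dist

# Single-process, single-rank setup, as in the tutorial.
os.environ.setdefault("RANK", "0")
os.environ.setdefault("WORLD_SIZE", "1")
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")

# Use NCCL when a CUDA device is present; otherwise fall back to gloo on CPU.
backend = "nccl" if torch.cuda.is_available() else "gloo"
dist.init_process_group(backend=backend)
print(dist.get_rank(), dist.get_world_size())  # 0 1
```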

From EmbeddingBag to EmbeddingBagCollection

PyTorch represents embeddings through torch.nn.Embedding and torch.nn.EmbeddingBag. EmbeddingBag is a pooled version of Embedding.

TorchRec extends these modules by creating collections of embeddings. We will use EmbeddingBagCollection to represent a group of EmbeddingBags.

Here, we create an EmbeddingBagCollection (EBC) with two embedding bags. Each table, product_table and user_table, is represented by a 64-dimension embedding of size 4096. Note how we initially allocate the EBC on the "meta" device. This tells the EBC not to allocate memory yet.

ebc = torchrec.EmbeddingBagCollection(
    device="meta",
    tables=[
        torchrec.EmbeddingBagConfig(
            name="product_table",
            embedding_dim=64,
            num_embeddings=4096,
            feature_names=["product"],
            pooling=torchrec.PoolingType.SUM,
        ),
        torchrec.EmbeddingBagConfig(
            name="user_table",
            embedding_dim=64,
            num_embeddings=4096,
            feature_names=["user"],
            pooling=torchrec.PoolingType.SUM,
        )
    ]
)
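To see what the "meta" device means in isolation, here is a quick sketch with a plain tensor (not the EBC itself): a meta tensor carries only shape and dtype metadata, with no storage allocated.

```python
import torch

# A "meta" tensor the size of one of our tables: metadata only, no memory.
t = torch.empty(4096, 64, device="meta")
print(t.is_meta)  # True
print(t.shape)    # torch.Size([4096, 64])
```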

DistributedModelParallel

Now we're ready to wrap our model with DistributedModelParallel (DMP). Instantiating DMP will:

  1. Decide how to shard the model. DMP will collect the available 'sharders' and come up with a 'plan' of the optimal way to shard the embedding tables (i.e., the EmbeddingBagCollection).
  2. Actually shard the model. This includes allocating memory for each embedding table on the appropriate device(s).

In this toy example, since we have two embedding tables and one GPU, TorchRec will place both on the single GPU.

model = torchrec.distributed.DistributedModelParallel(ebc, device=torch.device("cuda"))
print(model)
print(model.plan)

Query vanilla nn.EmbeddingBag with input and offsets

We query nn.Embedding and nn.EmbeddingBag with input and offsets. Input is a 1-D tensor containing the lookup values. Offsets is a 1-D tensor where the sequence is a cumulative sum of the number of values per example.

Let's look at an example, recreating the product EmbeddingBag above:

|------------|
| product ID |
|------------|
| [101, 202] |
| []         |
| [303]      |
|------------|
product_eb = torch.nn.EmbeddingBag(4096, 64)
product_eb(input=torch.tensor([101, 202, 303]), offsets=torch.tensor([0, 2, 2]))
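The offsets tensor used above can be derived from per-example value counts; a minimal sketch of that arithmetic (plain PyTorch, not a TorchRec API):

```python
import torch

# Per-example value counts for the product table: [2, 0, 1]
lengths = torch.tensor([2, 0, 1])

# offsets[i] marks where example i starts in the flat input tensor,
# i.e. the exclusive cumulative sum of the lengths.
offsets = torch.cat([torch.zeros(1, dtype=torch.long), lengths.cumsum(0)[:-1]])
print(offsets)  # tensor([0, 2, 2])
```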

Representing minibatches with KeyedJaggedTensor

We need an efficient representation of multiple examples with an arbitrary number of entity IDs per feature per example. To enable this "jagged" representation, we use the TorchRec data structure KeyedJaggedTensor (KJT).

Let's look at how to look up a collection of two embedding bags, "product" and "user". Assume the minibatch is made up of three examples for three users. The first has two product IDs, the second has none, and the third has one product ID.

|------------|------------|
| product ID | user ID    |
|------------|------------|
| [101, 202] | [404]      |
| []         | [505]      |
| [303]      | [606]      |
|------------|------------|

์งˆ์˜๋Š” ๋‹ค์Œ๊ณผ ๊ฐ™์Šต๋‹ˆ๋‹ค.

mb = torchrec.KeyedJaggedTensor(
    keys = ["product", "user"],
    values = torch.tensor([101, 202, 303, 404, 505, 606]).cuda(),
    lengths = torch.tensor([2, 0, 1, 1, 1, 1], dtype=torch.int64).cuda(),
)

print(mb.to(torch.device("cpu")))

Note that the KJT batch size is batch_size = len(lengths)//len(keys). In the above example, the batch_size is 3.
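To see how the flat values/lengths layout maps back to the jagged table above, here is a small illustrative decoding of the same minibatch (just the indexing arithmetic, not a TorchRec API):

```python
keys = ["product", "user"]
values = [101, 202, 303, 404, 505, 606]
lengths = [2, 0, 1, 1, 1, 1]  # len(lengths) == len(keys) * batch_size

batch_size = len(lengths) // len(keys)  # 3

# Walk the flat values, slicing out each key's examples in order:
# lengths is laid out key-major, batch_size entries per key.
jagged = {}
pos = 0
for i, key in enumerate(keys):
    per_example = []
    for length in lengths[i * batch_size:(i + 1) * batch_size]:
        per_example.append(values[pos:pos + length])
        pos += length
    jagged[key] = per_example

print(jagged)
# {'product': [[101, 202], [], [303]], 'user': [[404], [505], [606]]}
```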

์ด์ •๋ฆฌํ•˜์—ฌ, KJT ๋ฏธ๋‹ˆ๋ฐฐ์น˜๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ๋ถ„์‚ฐ ๋ชจ๋ธ ์งˆ์˜ํ•˜๊ธฐ

๋งˆ์ง€๋ง‰์œผ๋กœ ์ œํ’ˆ๊ณผ ์‚ฌ์šฉ์ž์˜ ๋ฏธ๋‹ˆ๋ฐฐ์น˜๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ๋ชจ๋ธ์„ ์งˆ์˜ํ•ฉ๋‹ˆ๋‹ค.

๊ฒฐ๊ณผ ์กฐํšŒ๋Š” KeyedTensor๋ฅผ ํฌํ•จํ•ฉ๋‹ˆ๋‹ค. ๊ฐ ํ‚ค(key) ๋˜๋Š” ํŠน์ง•(feature)์€ ํฌ๊ธฐ๊ฐ€ 3x64 (batch_size x embedding_dim)์ธ 2D ํ…์„œ๋ฅผ ํฌํ•จํ•ฉ๋‹ˆ๋‹ค.

pooled_embeddings = model(mb)
print(pooled_embeddings)
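What SUM pooling produces for each row of that 3x64 output can be sketched with a plain random tensor standing in for one 4096 x 64 table (hypothetical weights, not the model above):

```python
import torch

torch.manual_seed(0)
weights = torch.randn(4096, 64)  # stand-in for one embedding table

# Example 1 of the "product" feature has IDs [101, 202]:
# SUM pooling adds their embedding rows into a single 64-dim vector.
pooled = weights[torch.tensor([101, 202])].sum(dim=0)
print(pooled.shape)  # torch.Size([64])

# An empty example (no IDs) pools to an all-zero 64-dim vector,
# which is why every example still yields one row of the 3 x 64 output.
```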

More resources

For more information, please see our dlrm example, which includes multi-node training on the Criteo 1TB dataset using Meta's DLRM.