From f0ea511c5128eec33b5e3c3599bc6dc05cf943f2 Mon Sep 17 00:00:00 2001 From: Michael Fromm Date: Tue, 30 Jan 2024 18:12:22 +0100 Subject: [PATCH 1/9] feat: group-query-attention implementation --- src/modalities/models/gpt2/gpt2_model.py | 28 ++++++++++++++---------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/src/modalities/models/gpt2/gpt2_model.py b/src/modalities/models/gpt2/gpt2_model.py index f5356710..eef66c44 100644 --- a/src/modalities/models/gpt2/gpt2_model.py +++ b/src/modalities/models/gpt2/gpt2_model.py @@ -40,7 +40,8 @@ class GPT2Config(BaseModel): block_size: conint(ge=1) vocab_size: conint(ge=1) # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency n_layer: conint(ge=1) - n_head: conint(ge=1) + n_head_q: conint(ge=1) + n_head_kv: conint(ge=1) n_embd: conint(ge=1) ffn_hidden: conint(ge=1) dropout: confloat(ge=0.0) @@ -82,10 +83,10 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: class CausalSelfAttention(nn.Module): def __init__( - self, n_head: int, n_embd: int, attention: AttentionConfig, bias: bool, dropout: float, block_size: int + self, n_head_q: int, n_head_kv: int, n_embd: int, attention: AttentionConfig, bias: bool, dropout: float, block_size: int ): super().__init__() - assert n_embd % n_head == 0 + assert n_embd % n_head_q == 0 # key, query, value projections for all heads, but in a batch self.c_attn = nn.Linear( in_features=n_embd, @@ -103,7 +104,9 @@ def __init__( # regularization self.attn_dropout = nn.Dropout(dropout) self.resid_dropout = nn.Dropout(dropout) - self.n_head = n_head + self.n_head_q = n_head_q + self.n_head_kv = n_head_kv + self.n_embd = n_embd self.dropout = dropout self.flash = attention.attention_type == AttentionType.PYTORCH_FLASH_ATTENTION @@ -120,9 +123,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # calculate query, key, values for all heads in batch and move head forward to be the batch dim q, k, v = self.c_attn(x).split(self.n_embd, dim=2) - k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) - q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) - v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + k = k.view(B, T, self.n_head_kv, C // self.n_head_kv).transpose(1, 2) # (B, nh, T, hs) + q = q.view(B, T, self.n_head_q, C // self.n_head_q).transpose(1, 2) # (B, nh, T, hs) + v = v.view(B, T, self.n_head_kv, C // self.n_head_kv).transpose(1, 2) # (B, nh, T, hs) # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) if self.flash: @@ -180,7 +183,8 @@ def __init__( bias: bool, epsilon: float, activation: ActivationType, - n_head: int, + n_head_q: int, + n_head_kv: int, attention: AttentionConfig, dropout: float, block_size: int, @@ -189,7 +193,7 @@ def __init__( super().__init__() self.ln_1 = LayerNorm(ndim=n_embd, bias=bias, epsilon=epsilon) self.attn = CausalSelfAttention( - n_head=n_head, n_embd=n_embd, attention=attention, bias=bias, dropout=dropout, block_size=block_size + n_head_q=n_head_q, n_head_kv=n_head_kv, n_embd=n_embd, attention=attention, bias=bias, dropout=dropout, block_size=block_size ) self.ln_2 = LayerNorm(ndim=n_embd, bias=bias, epsilon=epsilon) @@ -215,7 +219,8 @@ def __init__( block_size: int, vocab_size: int, n_layer: int, - n_head: int, + n_head_q: int, + n_head_kv: int, n_embd: int, ffn_hidden: int, dropout: float, @@ -245,7 +250,8 @@ def __init__( bias=bias, epsilon=epsilon, activation=activation, - n_head=n_head, + n_head_q=n_head_q, + 
n_head_kv=n_head_kv, attention=attention, dropout=dropout, block_size=block_size, From 6f86f1b6a4bd4f1801c60def32c9a133b4b14786 Mon Sep 17 00:00:00 2001 From: Luzian Hahn Date: Mon, 11 Mar 2024 09:58:02 +0100 Subject: [PATCH 2/9] chore: merge main into GQA MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit commit 080755503c12ba250b83ba2864d993d4a73dd934 Author: Max Luebbering Date: Thu Mar 7 18:33:39 2024 +0100 refactor: deleted failing legacy test commit dd0db07bbe631e9dc30f35912076d26603f4a6b7 Merge: 095e491 4821804 Author: Luzian Hahn <145655920+luzian-hahn@users.noreply.github.com> Date: Thu Mar 7 10:29:09 2024 +0100 Merge pull request #48 from Modalities/feat/merge-pbin-files feat: merge utility for pbin files commit 48218040a00fb7ed8380e4cce954e70eda3362c9 Author: Luzian Hahn Date: Thu Mar 7 10:27:28 2024 +0100 docs: add hint about updated header structure commit b34d6cb4463573d95e138a09d73527086495c7ac Author: Luzian Hahn Date: Thu Mar 7 10:19:54 2024 +0100 refactor: remove unused utility commit 7d05448a600fe8374fab0ffbd1bcff9cfbe757ad Author: Luzian Hahn Date: Tue Feb 27 16:00:38 2024 +0100 refactor: remove redundant check for valid pbin files commit 2e273359df02e808be329c8ff73f8f9ef6ea1485 Author: Luzian Hahn Date: Mon Feb 5 18:21:51 2024 +0100 feat: add entrypoint for pbin-merge commit 8ffc095f9e5020ea61ce9ab5af5516875f526094 Author: Luzian Hahn Date: Mon Feb 5 18:16:06 2024 +0100 refactor: introduce entrypoint group "data" commit a0d13a366201734fd3cd4f257c756b607de2067b Author: Luzian Hahn Date: Mon Feb 5 15:06:18 2024 +0100 feat: add pbin-merger commit 9f853cf9e25cfa4c030500a55e2a7a4feb9f5d29 Author: Luzian Hahn Date: Mon Feb 5 11:36:49 2024 +0100 refactor: introduce abstraction for stream data below packed Datasets commit 095e491031e7e93e6a365c2c0d81a176baff1f85 Merge: 419fc9e 0f3846a Author: Luzian Hahn <145655920+luzian-hahn@users.noreply.github.com> Date: Thu Mar 7 09:38:53 2024 +0100 Merge pull request #40 from Modalities/perf/benchmark-datasets-again-megatronlm perf: benchmark datasets against megatronlm commit 0f3846a829035ddc0e36b9e548c1c29d3af86213 Author: Luzian Hahn Date: Thu Mar 7 09:28:27 2024 +0100 test: prevent unnecessary warnings during tests commit f2232c3e99417194736ee7157cf6efc4d27a49d9 Merge: 9095ac5 419fc9e Author: Luzian Hahn <145655920+luzian-hahn@users.noreply.github.com> Date: Thu Mar 7 08:45:14 2024 +0100 Merge branch 'main' into perf/benchmark-datasets-again-megatronlm commit 419fc9e458422eeb1f8f389ecdba675f0f6e9d41 Merge: 8ab29d0 d192331 Author: Max Lübbering <2804731+le1nux@users.noreply.github.com> Date: Mon Mar 4 12:25:00 2024 +0100 Merge pull request #65 from David-Berghaus/Fix-typos Fixed typos commit d19233198b9e2d62eda2d5734b21ad48865dc68b Author: David Berghaus Date: Mon Mar 4 12:12:47 2024 +0100 Fixed typos commit 8ab29d0c1262ca82bbe15409934bfbc28187408a Merge: d71bceb f9b0f41 Author: Mehdi Ali <33023925+mali-git@users.noreply.github.com> Date: Fri Mar 1 15:59:01 2024 +0100 Merge pull request #45 from Modalities/hierarchical_instantiation Hierarchical instantiation commit f9b0f416c3450c6e4e5c174aa1df89220a86eeba Author: Felix Stollenwerk Date: Mon Feb 26 16:53:36 2024 +0100 chore: fix linting commit 042e3a03e2e788d644486c6121288a5aecf14bfd Author: Felix Stollenwerk Date: Mon Feb 26 16:00:39 2024 +0100 refactor: fix typos commit 8345e06cfdce42bcda0d7cefce1f2bf3411d159d Author: Max Luebbering Date: Mon Feb 26 15:24:25 2024 +0100 refactor: fixed the library usage exampe commit 
cd2128d12bf2d990e0a6cd7ade48832d31dd8bc4 Author: Max Luebbering Date: Mon Feb 26 15:24:00 2024 +0100 refactor: replaced absolute paths with relative ones commit 9ab6654d507b49e3203622a3c3ef259f2fd11b02 Author: Max Luebbering Date: Mon Feb 26 15:23:06 2024 +0100 fix: fixed add_custom_component in Main commit 64b785a97c1f261f3421ca126f89307084c13c4e Author: Felix Stollenwerk Date: Mon Feb 26 14:18:06 2024 +0100 fix: skipping of tests in non-distributed environment commit c7f7a7b924cff509564dc7eb550c3c556f782409 Author: Max Luebbering Date: Mon Feb 26 13:35:24 2024 +0100 chore: minor changes in TestFSDPToDiscCheckpointing commit 10538ac07e8e4f021b08499245aca69deb5afed4 Author: Max Luebbering Date: Mon Feb 26 13:24:10 2024 +0100 refactor: also using ComponentEntity now in the tests commit 432426bacfbd333b1a6963a2695cd9e81a99b8b8 Author: Max Luebbering Date: Mon Feb 26 13:23:46 2024 +0100 refactor: fixed failing test_e2e_training_run_wout_ckpt commit 63829e1f7642456501a56d22b1be2230d1576b61 Author: Max Luebbering Date: Mon Feb 26 13:20:43 2024 +0100 chore: excluded openGPTx from test cov commit 12632fd18540651dfd675bb814096210c2e1df19 Author: Max Luebbering Date: Mon Feb 26 12:57:16 2024 +0100 refactor: introduced ComponentEntity commit c15de178b8b49fdb8be711abd4ffb073519af91b Author: Max Luebbering Date: Mon Feb 26 12:09:14 2024 +0100 refactor: various smaller changes commit 973909d23d39ddbdb73973fd458f9f9737db5d8f Author: Felix Stollenwerk Date: Mon Feb 26 10:58:27 2024 +0100 refactor: sort classes in config commit bc64ee05f56df84c00580617ef124a22793e9fae Author: Felix Stollenwerk Date: Mon Feb 26 10:52:21 2024 +0100 refactor: remove RegistryFactory commit b9dbe2e501916491b3940e216f1033338cf24b47 Author: Felix Stollenwerk Date: Mon Feb 26 10:19:24 2024 +0100 refactor: rename and fix readme for getting started example commit ca74340ff831d90b1fe9f7c9c9b82b21c4e411fa Author: Max Luebbering Date: Sun Feb 25 16:00:17 2024 +0100 feat: added activation checkpointing to __main__.py commit 7ae22343ecfb4d77039b4905c7a59e827c18a5d1 Author: Max Luebbering Date: Sat Feb 24 21:19:44 2024 +0100 refactor: fixed some of the configs commit bcd6e5b408aa0f17b1b2954e492cfe0e3e083e22 Author: Max Luebbering Date: Sat Feb 24 21:16:08 2024 +0100 feat: experiment_id now set in the config via omega conf resolver commit a6ea22a783a014564f525f74ec366d640bc651dd Author: Max Luebbering Date: Sat Feb 24 14:03:46 2024 +0100 refactor: gpt2 config for checkpointing tests commit ff3eb525ce0222e15d49f4101481b0bcddb3b62e Merge: 64617dd fb0aea5 Author: Max Luebbering Date: Sat Feb 24 14:01:15 2024 +0100 chore: Merge branch 'hierarchical_instantiation' of github.com:Modalities/modalities into hierarchical_instantiation commit 64617ddfea06354dec887cdd76d0ed6b7d9cbb6d Author: Max Luebbering Date: Sat Feb 24 14:00:35 2024 +0100 feat: added add_custom_component function to Main commit df4f971c8652e44cdde9a7451445b77386dbfd2c Author: Max Luebbering Date: Sat Feb 24 13:59:33 2024 +0100 test: fixed fsdp test, but cannot be run directly via pytest as it needs torchrun commit fb0aea5fe1fdb04dae9f6299ed2b5f39b4a3be17 Author: Felix Stollenwerk Date: Sat Feb 24 10:51:51 2024 +0100 fix: replace conint/confloat correctly commit fd07cb084fcc56d6c380fcbee0d04599b900c4fe Author: Max Luebbering Date: Fri Feb 23 19:39:12 2024 +0100 refactor: made base_model_to_dict public as it is great for testing commit aa0d64f7e499bd626930aaf7953f59b73a37cc67 Author: Max Lübbering <2804731+le1nux@users.noreply.github.com> Date: Fri Feb 23 18:31:54 2024 +0100 
Update README.md commit e70f3a0e5742b3e6d4e5e014628fad4fe765ea44 Author: Felix Stollenwerk Date: Fri Feb 23 17:57:15 2024 +0100 fix: replace conint/confloat for pydantic 3.0 compatibility commit 70d9e632eb4a21f33ad139fd8b53871371345ee3 Author: Max Luebbering Date: Fri Feb 23 17:40:38 2024 +0100 chore: more documentation commit 23960207055613fcf2f402e9d2519e3f5fe5f40f Merge: a68ddf4 021b7c2 Author: Max Luebbering Date: Fri Feb 23 17:39:09 2024 +0100 chore: Merge branch 'hierarchical_instantiation' of github.com:Modalities/modalities into hierarchical_instantiation commit a68ddf4fe949e741eed4f94351ab13ece017232d Author: Max Luebbering Date: Fri Feb 23 16:57:23 2024 +0100 feat: added example for registering a custom component commit 021b7c265c313a4fa096637187d7b4bf97459cc3 Author: Felix Stollenwerk Date: Fri Feb 23 11:38:32 2024 +0100 refactor: restored base_model_to_dict commit b619b4157401d18bebbfc29c20bfa2343f8aea14 Author: Felix Stollenwerk Date: Fri Feb 23 09:32:31 2024 +0100 refactor: replace base_model_to_dict by pydantic built-in method commit 34c6498622a4dd745aa4c15ca09bd17358612d2a Author: Felix Stollenwerk Date: Fri Feb 23 09:26:44 2024 +0100 refactor: fixed typing for registry commit 52ffea49524bd6bf87e923c9b4472628957446a4 Author: Max Luebbering Date: Thu Feb 22 17:59:20 2024 +0100 fix: fixed failing end 2 end test commit b0bd29669611212e63c311775ccb9851e42c7b51 Author: Max Luebbering Date: Thu Feb 22 17:58:38 2024 +0100 fix: eval_dataloaders are now treated as list instead of dict. This was not reflected yet in the subscriber factory commit cbf905b497de88db970d5e436b09c4fbe8a2eb2c Author: Max Luebbering Date: Thu Feb 22 17:47:53 2024 +0100 fix: checkpointing test commit a42a47922e81fbb15f4c8039cc766f515ef5edea Merge: 26b8b82 e3b50f6 Author: Max Luebbering Date: Thu Feb 22 17:33:21 2024 +0100 chore: Merge branch 'hierarchical_instantiation' of github.com:Modalities/modalities into hierarchical_instantiation commit 26b8b824384c278d793c79920f9bad447b93b0e1 Author: Max Luebbering Date: Thu Feb 22 17:32:41 2024 +0100 refactor: we fully support the configs again for hierarchical instantiation commit 9dfd100e3525c511271aa65d731253e6a49d6afe Author: Max Luebbering Date: Thu Feb 22 17:31:45 2024 +0100 refactor: eval_dataloaders are subsumed in a list now commit e3b50f65d8662ebd9306c3cc84b76629861d7c81 Author: Felix Stollenwerk Date: Thu Feb 22 12:39:17 2024 +0100 refactor: unification of Pydantic*IF classes commit 7c4fafbf0a8d46b8f88505ece04a473934f60f78 Author: Alexander Weber <12560547+lllAlexanderlll@users.noreply.github.com> Date: Thu Feb 22 09:24:42 2024 +0000 chore: enabled pytest discovery with all tests. Some tests still need to be fixed! 
commit 34dc796c7da34945db498849519a88cbd8c0406b Author: Felix Stollenwerk Date: Thu Feb 22 10:24:09 2024 +0100 refactor: renaming for consistency commit 2d8349d7b4034190eeffb0a223582f2344125916 Author: Alexander Weber <12560547+lllAlexanderlll@users.noreply.github.com> Date: Thu Feb 22 08:45:23 2024 +0000 fix: e2e test commit cc60608c7793d21d9e866f35dfdd25f10fd03826 Author: Alexander Weber <12560547+lllAlexanderlll@users.noreply.github.com> Date: Thu Feb 22 08:10:43 2024 +0000 fix: set FIXME for fsdp_to_disc_checkpointing_test and fix oudated config test commit fdfb90a4c2e480d90298cc70aeb37f9bcdcf58b5 Author: Max Luebbering Date: Wed Feb 21 19:03:04 2024 +0100 chore: fixed variable naming commit 1de69c355174e34c0cc99e3d5006b36020635243 Author: Max Luebbering Date: Wed Feb 21 18:59:23 2024 +0100 refactor: merged remote to local and refactored callback_interval_in_batches to callback_interval_in_samples in the config commit e1dd046e06c822d2bf11b04b805b449c0cb2e5e8 Author: Alexander Weber <12560547+lllAlexanderlll@users.noreply.github.com> Date: Wed Feb 21 15:22:33 2024 +0000 fix: test discovery under vscode. TODO: replace PretrainedGPTConfig by correct class commit cd5ec46dc408f1ca5e6d70aa278a70d074ada1a0 Merge: 281f20f e16dec9 Author: Max Luebbering Date: Wed Feb 21 13:15:31 2024 +0100 chore: Merge branch 'hierarchical_instantiation' of github.com:Modalities/modalities into hierarchical_instantiation commit 281f20f50512fe5e7715261e29e5d619bc20f3fa Author: Max Luebbering Date: Wed Feb 21 12:56:45 2024 +0100 refactor: moved LookupEnum to dedicated file to fix circular imports commit e433913a7d58459572689e6e4b601d9e9cf1e96c Author: Max Luebbering Date: Wed Feb 21 12:55:34 2024 +0100 refactor: removed types.py commit 2c3762bc8d6474adb3a3528a2227abd8399587bb Author: Max Luebbering Date: Wed Feb 21 12:48:12 2024 +0100 chore: import fix commit 4f07fc9790c623e447590e85029127495c148ac1 Author: Max Luebbering Date: Wed Feb 21 12:47:50 2024 +0100 feat: added checkpointed model and fsdp wrapped model to registry factory commit 2ba8edd9df1910c4fde448b859eb351737bc8fab Author: Max Luebbering Date: Wed Feb 21 12:46:26 2024 +0100 chore: fixed import in registry factory commit 76b4240966d50a04f194874195717596b41f6703 Author: Max Luebbering Date: Wed Feb 21 12:46:03 2024 +0100 chore: minor fix commit 417e0edf75ed8a3d26c3d75e849d23ee51aa87b5 Author: Max Luebbering Date: Wed Feb 21 12:45:48 2024 +0100 refactor: deleted checkpointing factory commit b056ddda129d9e7e905641b52ca4b273d84b15fb Author: Max Luebbering Date: Wed Feb 21 12:45:09 2024 +0100 refactor: we always instantiate the LLMDataloader with a ResumableBatchSampler now commit cd5e6fea8fc9ef8331a90bcc1c3827cbf224b375 Author: Max Luebbering Date: Wed Feb 21 12:43:20 2024 +0100 refactor: config_new.py renamed to config.py commit f39051f797df9b4361240b373dd04d31405b35cb Author: Max Luebbering Date: Wed Feb 21 12:41:48 2024 +0100 refactor: deleted lookup_types commit c971bb09789b58d481cf064f69aa44e49a867c8b Author: Max Luebbering Date: Wed Feb 21 12:39:47 2024 +0100 refactor: removed resolver_register commit 3371b39b2d64d4f43c7de6d035444030c774ee05 Author: Max Luebbering Date: Tue Feb 20 21:37:11 2024 +0100 refactor: __main__.py now is capable of instantiating hierarchical configs commit b5f3d4dc3e9df9bc46184ab1912a2aee6a7bbc02 Author: Max Luebbering Date: Tue Feb 20 21:34:25 2024 +0100 refactor: refactored FSDPToDiscCheckpointing to use ModelFactory.get_fsdp_wrapped_model commit 29aee7d5aa5d434e7ec5609fa81a04a509a3db99 Author: Max Luebbering Date: Tue Feb 
20 21:33:06 2024 +0100 chore: ProcessGroupBackendType inherits now from LookupEnum commit 197f8638d7efe1d5f08527abd3972f6c3b7e7005 Author: Max Luebbering Date: Tue Feb 20 21:32:36 2024 +0100 feat: implemented OptimizerFactory commit 8d1bb9e822b9b68dab91a94b568bc83b0a13abdd Author: Max Luebbering Date: Tue Feb 20 21:32:12 2024 +0100 feat: added model factory commit 8b9dc20d34f5ebfec8debe3ba0df0392f113e259 Author: Max Luebbering Date: Tue Feb 20 21:31:40 2024 +0100 feat: introduced CudaEnv commit 89fa61cdd38f918bf925c2b1cfe92943837e75cc Author: Max Luebbering Date: Tue Feb 20 21:31:15 2024 +0100 chore: MixedPrecisionSettings inherits now from LookupEnum commit 4037db24f5daaf9c37608f2c22c211c1070c73e5 Author: Max Luebbering Date: Tue Feb 20 21:30:50 2024 +0100 refactor: removed running env commit eb9f5b5c61753da8b5f2904fd2956b3eb6d4c520 Author: Max Luebbering Date: Tue Feb 20 21:30:32 2024 +0100 feat: added Settings basemodel to config and refactored FSDPToDiscCheckpointingConfig commit c60d689bb8d7432b1284f650f25b04fb2ea694e2 Author: Max Luebbering Date: Tue Feb 20 21:29:29 2024 +0100 refactor: restructured config lorem ipsum commit d9d8925f98b5cd0220179fc8f11423842a9be2b3 Author: Max Luebbering Date: Tue Feb 20 20:22:49 2024 +0100 fix: bug fix in component factory commit e16dec9b7e100962a9e0cf6f59f5db14371c82fd Merge: 4c17abb d71bceb Author: Felix Stollenwerk Date: Mon Feb 19 20:58:25 2024 +0100 chore: merge main into hierarchical_instantiation commit 4c17abbd7c440207342ed80929afdfca40a7edc0 Author: Felix Stollenwerk Date: Mon Feb 19 15:52:59 2024 +0100 refactor: unification of component registry and config registry commit d71bceb661b560ee9ccaa2a1ed25a089947d095c Author: Alexander Weber Date: Mon Feb 19 15:29:09 2024 +0100 Update README.md commit 95bfc55c0d322bed427527c4b1e71e83e71a6fef Merge: f16c409 a0b799a Author: Alexander Weber Date: Mon Feb 19 15:25:29 2024 +0100 Merge pull request #52 from Modalities/chore/add-pytest-coverage chore: add pytest coverage commit a0b799a3fd648647f021f3acea8429bf4dea4e78 Author: Alexander Weber <12560547+lllAlexanderlll@users.noreply.github.com> Date: Mon Feb 19 14:22:49 2024 +0000 chore: clean gitignore commit 5361ca584e181a46f23e9ba515b8e7cb7e0d80a0 Author: Alexander Weber <12560547+lllAlexanderlll@users.noreply.github.com> Date: Mon Feb 19 14:11:00 2024 +0000 chore: add toml support commit 4047b674b1fb943db1bfaa223230d94629a91226 Author: Alexander Weber <12560547+lllAlexanderlll@users.noreply.github.com> Date: Mon Feb 19 14:07:13 2024 +0000 chore: try fix from 2021 commit 20b1460be6edffc9fc59fdc2b56cae7c544d5d6a Author: Alexander Weber <12560547+lllAlexanderlll@users.noreply.github.com> Date: Mon Feb 19 13:54:42 2024 +0000 chore: remove outdated .coverage.toml commit ec495a3264cfa1b097c245013efbbd70c7800fed Author: Alexander Weber <12560547+lllAlexanderlll@users.noreply.github.com> Date: Mon Feb 19 13:45:58 2024 +0000 chore: remove --cov from github action commit 920ccab619fc6c49ca394322cb4d7b535b6f313b Author: Alexander Weber <12560547+lllAlexanderlll@users.noreply.github.com> Date: Mon Feb 19 13:45:13 2024 +0000 chore: add coverage options in pyproject.toml commit a3ce9b16c278df7a81da22976629253a1c1f0f8f Author: Max Luebbering Date: Mon Feb 19 14:44:59 2024 +0100 feat: integrated message subscribers commit b324c3f5520ebfe5caf8db0dd55a5aa4ac5177d8 Author: Max Luebbering Date: Mon Feb 19 14:41:37 2024 +0100 refactor: refactored dataloader and its factory commit f686268de7e7301eb630b073b102bb3e34ba254d Author: Alexander Weber 
<12560547+lllAlexanderlll@users.noreply.github.com> Date: Mon Feb 19 13:41:30 2024 +0000 chore: add pytest --cov arguments by default commit f1e315538f7e3ff487b5d103ae4d9dfb5ed099ca Author: Alexander Weber <12560547+lllAlexanderlll@users.noreply.github.com> Date: Mon Feb 19 13:36:22 2024 +0000 chore: search for coverage bug commit 4122c6c6de49eadec85c0e6b7210a5d3cd12c9aa Author: Alexander Weber <12560547+lllAlexanderlll@users.noreply.github.com> Date: Mon Feb 19 13:31:43 2024 +0000 chore: search for coverage bug commit f43b81fcfeb2eec72f3f08c0f7b69ae68031d8d0 Author: Alexander Weber <12560547+lllAlexanderlll@users.noreply.github.com> Date: Mon Feb 19 13:08:27 2024 +0000 chore: fix coveralls github action commit 81292e8953ab2736ee301b023d82b3f98095e4f7 Author: Max Luebbering Date: Mon Feb 19 14:03:40 2024 +0100 refactor: moved OpenGPTXDatasetWrapper to DatasetFactory commit bc56246ddcd4f071f270bb880ae2941f5f0fcd0d Author: Alexander Weber <12560547+lllAlexanderlll@users.noreply.github.com> Date: Mon Feb 19 12:41:54 2024 +0000 chore: add pytest-cov execution as github action commit f16c409bd6169d7ecba838ae62fc12230fec45f9 Merge: a0513e3 bc03021 Author: Alexander Weber Date: Mon Feb 19 11:05:36 2024 +0100 Merge pull request #56 from Modalities/fix/tests fix: use renamed tokenizer file name commit bc03021b2039c4d84ccc9ce1ed6bee24521f6ea8 Author: Alexander Weber <12560547+lllAlexanderlll@users.noreply.github.com> Date: Mon Feb 19 09:48:47 2024 +0000 fix: use renamed tokenizer file name commit a0513e39ab5988c0bdbd1b9b18997f56ac49844f Merge: b8117b1 76e0518 Author: Alexander Weber Date: Mon Feb 19 10:26:45 2024 +0100 Merge pull request #38 from Modalities/fix/tests-on-cpu commit 76e05181c4cf2b087b126fea09e4ab9573b7bfe6 Author: Alexander Weber <12560547+lllAlexanderlll@users.noreply.github.com> Date: Mon Feb 19 09:24:48 2024 +0000 chore: moved if statement into torch.device commit b8117b101f591bfc253edc801c2c3f6c5c34cb88 Merge: 1c99963 78b9645 Author: Alexander Weber Date: Mon Feb 19 10:11:56 2024 +0100 Merge pull request #42 from Modalities/fix/linting fix: lint all files commit 78b964500d87b9dc413fac86a35f8b5ea4155960 Merge: 5b60c2f 1c99963 Author: Alexander Weber <12560547+lllAlexanderlll@users.noreply.github.com> Date: Mon Feb 19 09:05:44 2024 +0000 chore: local merge commit 22676058d92ef6dda5d9acd51b6b9c69d2175335 Author: Max Luebbering Date: Sun Feb 18 23:27:27 2024 +0100 feat: towards subscriber support with hierarchical instantiation commit a4491195d5189aeadd79dbdc1914c1afab6524d5 Author: Max Luebbering Date: Sun Feb 18 23:25:40 2024 +0100 chore: minor changes commit aab3fa2bc641d3154d174e69cb2b1323ae80e65b Author: Max Luebbering Date: Sun Feb 18 23:24:58 2024 +0100 feat: implemented subscriber factory commit 1c99963ada873abad70e3a955226ffe15d7f5d1c Merge: a8b6563 cf27873 Author: Max Lübbering <2804731+le1nux@users.noreply.github.com> Date: Sun Feb 18 22:45:14 2024 +0100 Merge pull request #29 from Modalities/feat/contrastive_loss Add Noise Contrastive Estimation Loss commit 6baf221b576f8e688b5c1b40f699180e9012870d Author: Max Luebbering Date: Sat Feb 17 17:11:11 2024 +0100 feat: added LLM dataloader support commit 8ab04a5fae82094d6950e1e1a517383ea8c5a736 Author: Max Luebbering Date: Sat Feb 17 17:10:16 2024 +0100 feat: introduced CollateFnIF for colleate functions commit 018c2785f4ddc03c4199571c18426a213338ca78 Author: Max Luebbering Date: Sat Feb 17 17:00:02 2024 +0100 feat: added resumable batch sampler commit 1273c319f48c367635eed298ed25c7a7d4fb9b57 Author: Max Luebbering Date: Sat Feb 
17 16:53:57 2024 +0100 feat: added gpt_2 collator support commit 536447ce8c139adfbd239b0c2ffe2f5118dc2327 Author: Max Luebbering Date: Sat Feb 17 16:44:57 2024 +0100 feat: added batch sampler support commit 771eab11d18c1c963c2970382496857103fb7fc3 Author: Max Luebbering Date: Sat Feb 17 16:18:51 2024 +0100 feat: added PydanticDatasetIF for SamplerConfig commit f1c1be4a9c635771d54ad2d58057cc9a4ee29f1b Author: Max Luebbering Date: Sat Feb 17 15:55:47 2024 +0100 feat: added support for the different dataset formats commit 0824bb00329b46655671df5c9ce0146fddc8005f Author: Max Luebbering Date: Sat Feb 17 15:38:51 2024 +0100 refactor: added adaptations that were injected in the dataloader factory previously commit 6985fad90714907a3c0c5a47a72dbe0ab55b995c Author: Max Luebbering Date: Sat Feb 17 15:26:22 2024 +0100 feat: implemented dataset factory for various dataset types commit 81022f4f4c5d61fcc96a357fb12e5545f15d6635 Author: Max Luebbering Date: Sat Feb 17 14:33:37 2024 +0100 feat: added gpt2 tokenizer support commit 55c01104190bd83e08c6edba086cdc44c4d51ae2 Author: Max Luebbering Date: Sat Feb 17 13:35:19 2024 +0100 feat: added adamw support commit 4a6a415359c03e9efd41c8ed40e19ad057cb5839 Author: Max Luebbering Date: Sat Feb 17 13:32:30 2024 +0100 feat: implemented OptimizerFactory commit c2bd570950ecffc0256f5536e56907b26114f2e2 Author: Max Luebbering Date: Sat Feb 17 13:31:59 2024 +0100 fix: added root-level to dict function for basemodel to prevent recursive model dumps commit 90207edc611478dab5f0ceb4eafbd86e075639ec Author: Max Luebbering Date: Fri Feb 16 20:05:47 2024 +0100 refactor: started refactoring the lorem ipsum config towards the new hierarchical configs commit 13042413b40b3091374cf9946cb66c43f9d5e213 Author: Max Luebbering Date: Fri Feb 16 20:05:24 2024 +0100 refactor: Main makes partially use of the hierarchical instantiation now commit f7dfe3123ded20435c217091a6eeae6aa4a080d4 Author: Max Luebbering Date: Fri Feb 16 20:04:54 2024 +0100 refactor: Refactored CheckpointingFactory commit 38499c41f207baf6aeb2152c56e9cbbe05f5fe72 Author: Max Luebbering Date: Fri Feb 16 20:04:26 2024 +0100 refactor: removed unused atribute in Checkpointing commit 542ba75e849f881c1aae36b2ccb2f05a98a17596 Author: Max Luebbering Date: Fri Feb 16 20:04:08 2024 +0100 fix: bugfix in component factory commit d446260353c2e290e3fe2654062689bd289151ea Author: Max Luebbering Date: Fri Feb 16 20:03:54 2024 +0100 feat: added new configs in separate file for now commit 6d121f39d13167fb98454647aa0a0575180d9cc4 Author: Max Luebbering Date: Fri Feb 16 20:03:18 2024 +0100 feat: added more components to registry factory commit fb3b35f7bcade0f057effd26ad26b2381ad2d792 Author: Max Luebbering Date: Fri Feb 16 20:02:48 2024 +0100 refactor: refactored FSDPRunningEnvConfig commit 8eda99c93553845e9da90f8bcae2076356879b71 Author: Max Luebbering Date: Thu Feb 15 23:37:33 2024 +0100 refactor: refactored component factory to use the registry commit 41be77311a079a98231a6df0019bc448937e5c28 Author: Max Luebbering Date: Thu Feb 15 23:36:57 2024 +0100 feat: added registry factory commit 3ebb65671a3b92fea3440e133bcf8e523596c7f6 Author: Max Luebbering Date: Thu Feb 15 23:36:34 2024 +0100 feat: implemented registry commit f2164a83b529cac831457573f2aaf48967ed53dd Author: Max Luebbering Date: Thu Feb 15 23:36:00 2024 +0100 test: configs now use the new format without typehints commit 5fb2199df16b6496e528ac31d5f4cd472c3a3674 Author: Max Luebbering Date: Thu Feb 15 23:35:39 2024 +0100 test: added registry testing commit 
623f847771418edaf3542beeacda8c73f7654d81 Author: Max Luebbering Date: Thu Feb 15 23:34:49 2024 +0100 test: updated test configs to the new format commit 372947b3766978b9c2ebce994dd2da3d20e721f3 Author: Felix Stollenwerk Date: Wed Feb 14 21:45:11 2024 +0100 chore: add pytest coverage (locally) commit 36bc7ae8a36cc3c342d8d13ef5928d11b7a4bcd1 Author: Max Luebbering Date: Wed Feb 14 13:45:11 2024 +0100 refactor: renamed config_types to custom_config_types in ComponentFactory commit babd5970fdcaf79fc40b6fc3b366986194050361 Author: Max Luebbering Date: Wed Feb 14 13:35:58 2024 +0100 feat: added support custom types in component factory commit 1639a6aad855f337560fae94ded62aa3aa37208c Author: Max Luebbering Date: Wed Feb 14 12:04:00 2024 +0100 refactor: simplified ComponentFactory commit aa9e04043b595e66a58f856a9d35d9b42385fefe Author: Max Luebbering Date: Wed Feb 14 10:33:11 2024 +0100 test: removed code duplication in test_component_factory commit 44677c6007aca6183b98f43d03fc3574057927de Author: Max Luebbering Date: Wed Feb 14 10:30:43 2024 +0100 test: refactored test_custom_component commit 71de3ff3f0426bf0a56fcf4b7c75846226aaa55c Author: Max Luebbering Date: Tue Feb 13 21:44:53 2024 +0100 test: added testing for custom components commit 2a54f84a59bfa125499d0a05a2624163be968d9c Author: Max Luebbering Date: Tue Feb 13 20:57:25 2024 +0100 test: added test yaml configs for component factory commit 35236d093625402728d6e6693f4bd7ca9fca7437 Author: Max Luebbering Date: Tue Feb 13 20:56:22 2024 +0100 test: implemented test_non_existing_reference commit bb4bcb3fbbb5feeaa3c5306b4060fa882de7b912 Author: Max Luebbering Date: Tue Feb 13 20:53:21 2024 +0100 test: implemented test_component_filter commit 0dfbbcbee9f1c19920777e77c06d7d6f0140ac99 Author: Max Luebbering Date: Tue Feb 13 20:49:36 2024 +0100 test: implemented test_hierarchical_component_instantiation commit 3a66b6527f49d026361a5cd7cab4121672d3d465 Author: Max Luebbering Date: Tue Feb 13 20:41:18 2024 +0100 test: implemented forward and backward referencing test commit c0c877ca5e177487067c1474f63dfd0a4485c125 Author: Max Luebbering Date: Mon Feb 12 14:52:13 2024 +0100 chore: fixed imports in component factory commit a9781a31da91f0cb42e55505b3b46ccda30e16fc Author: Max Luebbering Date: Mon Feb 12 14:42:28 2024 +0100 refactor: added drafted test code for component factory commit b1cbb46fb932dac6ab54df7b5f76ca1689d86361 Author: Max Luebbering Date: Mon Feb 12 14:42:00 2024 +0100 refactor: moved trial component factory code to test module commit c115b2b3b64a689cbb0a868aaaf1c1789e9da343 Author: Max Luebbering Date: Mon Feb 12 14:41:21 2024 +0100 refactor: moved component factory into parent module commit e678d78558c4f539b1790b44efd763e61379421f Author: Max Luebbering Date: Mon Feb 12 14:26:55 2024 +0100 refactor: renamed hierarchical DI module to hierarchical_instantiation commit 45f7ff4149b13c3cb45789b8584933ba9932574d Author: Max Luebbering Date: Mon Feb 12 14:19:24 2024 +0100 refactor: removed legacy code and added comments to component factory. 
commit da88895a0f296fcc72fcfb27deb7ab4f7f510d90 Author: Max Luebbering Date: Mon Feb 12 14:03:26 2024 +0100 feat: added referencing to config commit b42aeeb420114e3037632050235a826e294acc79 Author: Max Luebbering Date: Mon Feb 12 14:02:50 2024 +0100 feat: added ReferenceConfig and PassType commit 72f05246428615614b557c1f5026fcfe17f8e68e Author: Max Luebbering Date: Mon Feb 12 14:02:24 2024 +0100 feat: implemented forward and backward component referencing commit 43e113474bc40e09910fc7aad9904dc0e6faa537 Author: Max Luebbering Date: Sun Feb 11 19:56:18 2024 +0100 chore: added documentation for generate_text text CMD interface commit cf2787367a631f0d4659de56bd9ff3f2decc75df Author: Sogol Haghighat Date: Fri Feb 9 17:38:31 2024 +0100 refactor: adapt nce_loss function to reflect loss from CoCa paper commit d388d210f4f6cd083abc5b3b3606b824b6ba40ca Author: Sogol Haghighat Date: Fri Feb 9 17:37:35 2024 +0100 test: adapt test_nce_loss_correctness to uni and bidirectional loss commit a8b6563a22551a296f569f8f570bdd017f7d7c8b Merge: da65493 00e10ae Author: Max Lübbering <2804731+le1nux@users.noreply.github.com> Date: Fri Feb 9 16:46:36 2024 +0100 Merge pull request #30 from Modalities/huggingface_models_support feat: Generic huggingface transformer support commit 00e10aef624a9ab672f1f79f9ab11ff7c29ad297 Author: Max Lübbering <2804731+le1nux@users.noreply.github.com> Date: Fri Feb 9 16:24:02 2024 +0100 Update preprocess_dataset.py commit e93e767226490f94f008741262f289c26fe70027 Merge: f435fc8 da65493 Author: Max Lübbering <2804731+le1nux@users.noreply.github.com> Date: Fri Feb 9 15:50:03 2024 +0100 Merge branch 'main' into huggingface_models_support commit f435fc8e9d00faaee612b0487d93c7f6e39424f1 Author: Max Luebbering Date: Fri Feb 9 15:46:59 2024 +0100 feat: introduced huggingface_prediction_subscription_key to HuggingFacePretrainedModelConfig to support different output formats commit e6f4aac25db1a064c07e7031728a08093ac6c043 Author: Max Luebbering Date: Fri Feb 9 15:46:08 2024 +0100 refactor: moved lookup_enum to dedicated file. 
commit ebbe8c5243af49f400d052cf1b4388475f393765 Author: Sogol Haghighat Date: Fri Feb 9 13:49:42 2024 +0100 test: add test for nce_loss using a manually calculated example commit 7d5c09559ee18601282322d9089a3ea8606a1b53 Author: Max Luebbering Date: Thu Feb 8 20:17:36 2024 +0100 chore: removed legacy code commit dad3ea4fd2d5ec2403a243ab6b33e8c0117baaf3 Author: Max Luebbering Date: Thu Feb 8 20:16:40 2024 +0100 chore: added legacy trials for hierarchical DI commit 3ab9ff363b0c12d07bcca8edfd781fc54ba732d7 Author: Max Luebbering Date: Thu Feb 8 20:14:10 2024 +0100 chore: added __init__.py commit 3dfdb2acc9d2570a5e4b33750ef4252e2affa759 Author: Max Luebbering Date: Thu Feb 8 20:13:51 2024 +0100 feat: implemented factory for hierarchical component instantiation commit dc7c1a2d41a7434782d9e057fb6c64f3be6efdc7 Author: Max Luebbering Date: Thu Feb 8 20:13:17 2024 +0100 feat: added example yaml config file for hierarchical instantiation commit 099979be89e1addbd456515741c7b2269503bbf5 Author: Max Luebbering Date: Thu Feb 8 20:12:58 2024 +0100 feat: added configs for the test components commit c4292ce27c729b12b2ea6b58f484d780e5c5ad2c Author: Max Luebbering Date: Thu Feb 8 20:12:25 2024 +0100 feat: added components for testing commit fc5cb963d3b3e6aefd3b10dba8882b231fb5f78e Author: Max Luebbering Date: Thu Feb 8 20:11:28 2024 +0100 chore: minor debugging improvement in parse_enum_by_name in utils commit 783ad8197c64628774d1455642e5b06b01772342 Author: Max Luebbering Date: Thu Feb 8 20:10:57 2024 +0100 chore: removed legacy trials commit 9095ac563382abc04cf78d309daf127f9e0a0e58 Author: Luzian Hahn Date: Tue Feb 6 16:17:22 2024 +0100 docs: update times in table after perf upgrade commit 91ec38e96cc0ceb240aa700d563aaa4d6dbd70cc Author: Luzian Hahn Date: Tue Feb 6 16:07:46 2024 +0100 fix: make encoding specification obsolete and improve perf of index creation commit afae8589e864741055dff1d4474b9ff80ab6c92f Author: Luzian Hahn Date: Tue Feb 6 15:48:19 2024 +0100 feat: make encoding configurable commit 71f77e2424448bc22fd2a20f845c6521efcbf024 Author: Luzian Hahn Date: Tue Feb 6 14:51:57 2024 +0100 refactor: remove parameter-artifact commit a6686200909a44fc98c72f1030f326489a408012 Author: Luzian Hahn Date: Tue Feb 6 14:47:52 2024 +0100 refactor: remove TODO-artifact commit a08518f9bfd55bbe4258fabb3699b9f8bdd21db7 Author: Luzian Hahn Date: Tue Feb 6 14:43:31 2024 +0100 refactor: rename queue for token-writing commit 2e535a372fd4e77dd17a5badf569234559f08c96 Author: Luzian Hahn Date: Tue Feb 6 14:25:35 2024 +0100 fix: derive default value for cpu count automatically commit 03d3f47e94784478c579fa893eb77acc4c0eec15 Author: Luzian Hahn Date: Tue Feb 6 14:24:48 2024 +0100 perf: share FileIOStream among process calls - not threadsafe! 
commit bc086caaa9dcd5be683742d2913b478f0c1c304e Author: Luzian Hahn Date: Tue Feb 6 14:13:12 2024 +0100 docs: remove auto execution of benchmarks, while sourcing bench utils commit fb04dc8fee02c42ea163b0d5b8889d53057dd254 Author: Luzian Hahn Date: Tue Feb 6 14:08:14 2024 +0100 fix: typo in warning commit faa2eff5608edcf2167cb9acbc5c593331e6bef4 Author: Luzian Hahn Date: Tue Feb 6 14:05:35 2024 +0100 docs: unify time units in measurement table commit 26ade7c7eaf8043521e7fa1d831d3fb1985c8bb0 Author: Luzian Hahn Date: Tue Feb 6 13:37:22 2024 +0100 docs: add definitions of benchmarking experiments commit 463872d8af31ed61a82195c64e2e051733eea652 Author: Max Luebbering Date: Mon Feb 5 18:55:00 2024 +0100 refactor: drafted hierarchical instantiation commit bd39244e588e47f7467d2dac4ce391d52821745f Author: Max Luebbering Date: Mon Feb 5 18:52:20 2024 +0100 chore: removed unused properties in config.py commit a908e7a1089e8b18dcb25ef9f1fef6ffbaefb227 Author: Max Luebbering Date: Mon Feb 5 18:50:26 2024 +0100 refactor: moved resolver register commit 540afe2a177cd6ab2a4c8cac94bc3b5228c32504 Author: Sogol Haghighat Date: Thu Feb 1 17:04:22 2024 +0100 refactor: add keyword arguments commit 57ccaf9db2a621b974750ff8a9c729ccda367d37 Author: Sogol Haghighat Date: Thu Feb 1 17:03:18 2024 +0100 refactor: introduce nce_loss function and add asymmetry parameter in NCELoss commit 35ca235323758948ca3a9385e2001e7506587a25 Author: Max Luebbering Date: Thu Feb 1 15:57:13 2024 +0100 feat: drafted hierarchical instantiation commit 5b60c2f771bd0563644b054d171999540fa9ffc0 Author: Felix Stollenwerk Date: Tue Jan 30 22:48:35 2024 +0100 fix: lint all files commit d84353f876deadd1a06d90b42d429b43803abdc5 Author: Luzian Hahn Date: Tue Jan 30 17:11:02 2024 +0100 docs: add details about dataloading performance benchmarks commit 93d924163486f4e8064ef30cc617ef03643cf7fe Author: Luzian Hahn Date: Tue Jan 30 17:10:12 2024 +0100 perf: use one large memmap for PackedDatasets commit e6cb130dcde391fca5843fb9aaa24ca0642f66dc Author: Sogol Haghighat Date: Tue Jan 30 16:18:50 2024 +0100 refactor: apply ruff refactor comment commit dfbefcbd372d28173c2faeaa4be62026f108dbe9 Author: Felix Stollenwerk Date: Tue Jan 30 15:23:42 2024 +0100 fix: get rid of reduce mocking (for testing) commit f4e3c563f6265905f83fcfb84b8444bdd1088159 Author: Felix Stollenwerk Date: Tue Jan 30 15:17:10 2024 +0100 fix: training and evaluation on CPU (for testing) commit 69e2050f8a7375ace47bfd6d28f3e1fb85f8780b Author: Luzian Hahn Date: Tue Jan 30 12:14:21 2024 +0100 feat: infer smallest tokensize automatically for packing commit a96a5f4bfd843bcdb3d5aeff3de2aeb296632cb8 Author: Luzian Hahn Date: Tue Jan 30 09:17:35 2024 +0100 perf: use parallelized tokenization when creating .pbin files commit ee08a01e0015ea2334b30df490334cf33f1ae9a7 Author: Luzian Hahn Date: Mon Jan 29 15:35:55 2024 +0100 perf: increase memmap index creation speed commit 8e30e008fd376964fa22221cddbb937f9bc5cc49 Author: Max Luebbering Date: Sun Jan 28 23:22:39 2024 +0100 chore: added documentation commit abb63aad014a384e252ac3d28a73f010d06231af Author: Max Luebbering Date: Sun Jan 28 22:08:45 2024 +0100 refactor: fixed configs due to latest changes commit f83da11f231b0c50635013bd2643aa38e972575c Author: Max Luebbering Date: Sun Jan 28 22:07:26 2024 +0100 feat: wired up huggingface transformer models commit 93095055a651331e209dcec39882359eb664a8ce Author: Max Luebbering Date: Sun Jan 28 22:01:29 2024 +0100 chore: renamed Block to GPT2Block commit 4d6a5ffb18ab590e6b6bf088ca0bf36408823c7a Author: Max 
Luebbering Date: Sun Jan 28 22:01:17 2024 +0100 feat: fully implemented HuggingFacePretrainedModel with respective configuration commit 88c4fdb479c4649e0fdfd0016b40d0b7ae9066ec Author: Max Luebbering Date: Sun Jan 28 22:00:33 2024 +0100 feat: implemented automatic FSDP wrapping commit 3b51117b9cd63d7805d8818547a8983f84169c8a Author: Max Luebbering Date: Sun Jan 28 21:56:48 2024 +0100 refactor: renamed tokenizer.json to tokenizer_gpt2.json commit 95e67a0ea8273c63c464973be9bc6508ffc8b72b Author: Max Luebbering Date: Sun Jan 28 21:55:52 2024 +0100 feat: renamed redpajama memmap datasets (added tokenizer info) commit 0992d2100ddfcb716bb16f2e5dd80e283ffa24fc Author: Max Luebbering Date: Sun Jan 28 00:26:36 2024 +0100 feat: towards generic huggingface transformer support commit ba65580ab21dfafd2b989c89cdafb3b2e1ec50c2 Author: Sogol Haghighat Date: Fri Jan 26 13:36:38 2024 +0100 refactor: refactor docstrings commit e459321806f8b5db752b3a15692bf3eef7ff3c80 Author: Sogol Haghighat Date: Thu Jan 25 17:48:32 2024 +0100 test: add test for contrastive loss commit bb14749824d9d0af3ca909278265bdc6e2a11e4f Author: Sogol Haghighat Date: Thu Jan 25 17:47:43 2024 +0100 feat: add contrastive loss for coca model training commit c9e4e084cb4f9a4f9fab89e094390e7d1cd27b96 Author: Luzian Hahn Date: Mon Jan 22 13:43:46 2024 +0100 fix: rely again on iso-8859-1 instead of utf8 the OpenGPT-X data seems to come with problematic chars, which cannot get edecoded via utf8. The former fix to use iso-8859-1 fixes this. However the issue probably lays actually with dataset conversions --- .github/workflows/tests.yml | 7 + .gitignore | 1 + CONTRIBUTING.md | 1 + README.md | 28 +- benchmarks/dataloader/README.md | 77 +++ benchmarks/dataloader/launch_benchmark.sh | 87 +++ ...ig_example_hf_meditron_7B_instruction.yaml | 199 ++++++ .../config_example_mem_map_dataset.yaml | 288 +++++---- config_files/config_lorem_ipsum.yaml | 336 ++++++---- ...2_gpt2_tokenized_num_samples_512_test.idx} | Bin ..._gpt2_tokenized_num_samples_512_test.pbin} | Bin ..._gpt2_tokenized_num_samples_512_train.idx} | Bin ...gpt2_tokenized_num_samples_512_train.pbin} | Bin .../{tokenizer.json => tokenizer_gpt2.json} | 0 docs/source/configuration.rst | 7 +- docs/source/memmap.rst | 8 +- docs/source/quickstart.rst | 8 +- .../{getting_started_example.md => README.md} | 44 +- examples/library_usage/README.md | 79 +++ .../library_usage/config_lorem_ipsum.yaml | 245 +++++++ examples/library_usage/main.py | 55 ++ examples/library_usage/run.sh | 3 + examples/pretraining_llama2/train.sh | 3 + pyproject.toml | 34 +- scripts/train.sh | 2 +- src/modalities/__main__.py | 343 +++++----- src/modalities/activation_checkpointing.py | 4 +- src/modalities/batch.py | 5 +- src/modalities/checkpointing/checkpointing.py | 6 +- .../checkpointing/checkpointing_execution.py | 33 +- .../checkpointing/checkpointing_factory.py | 36 -- src/modalities/config/component_factory.py | 144 +++++ src/modalities/config/config.py | 607 +++++++++--------- src/modalities/config/lookup_enum.py | 8 + src/modalities/config/lookup_types.py | 83 --- src/modalities/config/types.py | 5 - src/modalities/dataloader/create_index.py | 68 +- .../dataloader/create_packed_data.py | 260 ++++++-- src/modalities/dataloader/dataloader.py | 40 +- .../dataloader/dataloader_factory.py | 90 +-- src/modalities/dataloader/dataset.py | 138 ++-- src/modalities/dataloader/dataset_factory.py | 85 +++ .../dataloader/large_file_lines_reader.py | 27 +- .../open_gptx_dataset/open_gptx_dataset.py | 20 - 
src/modalities/dataloader/samplers.py | 4 +- src/modalities/evaluator.py | 28 +- src/modalities/exceptions.py | 2 +- .../logging_broker/message_broker.py | 5 +- src/modalities/logging_broker/publisher.py | 3 +- src/modalities/logging_broker/subscriber.py | 2 +- .../subscriber_impl/results_subscriber.py | 24 +- .../subscriber_impl/subscriber_factory.py | 76 +++ src/modalities/loss_functions.py | 80 +++ src/modalities/models/gpt2/collator.py | 11 +- src/modalities/models/gpt2/gpt2_model.py | 56 +- .../models/gpt2/preprocess_dataset.py | 18 +- src/modalities/models/huggingface/__init__.py | 1 + .../models/huggingface/huggingface_models.py | 84 +++ src/modalities/models/model.py | 4 +- src/modalities/models/model_factory.py | 44 ++ src/{ => modalities/optimizers}/__init__.py | 0 .../optimizers/optimizer_factory.py | 21 + src/modalities/registry/__init__.py | 0 src/modalities/registry/components.py | 157 +++++ src/modalities/registry/registry.py | 44 ++ src/modalities/resolver_register.py | 147 ----- src/modalities/running_env/cuda_env.py | 27 + src/modalities/running_env/env_utils.py | 6 +- .../running_env/fsdp/fsdp_auto_wrapper.py | 55 ++ .../running_env/fsdp/fsdp_running_env.py | 111 ---- src/modalities/running_env/running_env.py | 15 - src/modalities/test.py | 3 +- src/modalities/trainer.py | 44 +- src/modalities/util.py | 30 +- src/modalities/utils/generate_text.py | 27 +- tests/checkpointing/gpt2_config.yaml | 45 +- .../test_checkpoint_execution_functions.py | 15 +- .../test_fsdp_to_disc_checkpointing.py | 126 ++-- tests/config/__init__.py | 0 tests/config/components.py | 48 ++ tests/config/configs.py | 30 + tests/config/custom_components.py | 35 + tests/config/test_component_factory.py | 111 ++++ .../config_backward_reference.yaml | 28 + .../config_forward_reference.yaml | 27 + .../config_hierarchical_list_component.yaml | 15 + .../config_non_existing_reference.yaml | 17 + .../test_configs/config_single_component.yaml | 5 + tests/conftest.py | 20 +- tests/dataloader/test_dataloader.py | 37 +- .../test_large_file_lines_reader.py | 9 +- tests/dataloader/test_packed_dataset.py | 45 +- tests/test_evaluation.py | 52 -- tests/test_gym.py | 25 +- tests/test_loss_functions.py | 38 ++ tests/test_main.py | 16 +- 96 files changed, 3567 insertions(+), 1820 deletions(-) create mode 100644 benchmarks/dataloader/README.md create mode 100755 benchmarks/dataloader/launch_benchmark.sh create mode 100644 config_files/config_example_hf_meditron_7B_instruction.yaml rename data/sample_datasets/redpajama_v2/mem_map/{redpajama_v2_samples_512_test.idx => redpajama_v2_gpt2_tokenized_num_samples_512_test.idx} (100%) rename data/sample_datasets/redpajama_v2/mem_map/{redpajama_v2_samples_512_test.pbin => redpajama_v2_gpt2_tokenized_num_samples_512_test.pbin} (100%) rename data/sample_datasets/redpajama_v2/mem_map/{redpajama_v2_samples_512_train.idx => redpajama_v2_gpt2_tokenized_num_samples_512_train.idx} (100%) rename data/sample_datasets/redpajama_v2/mem_map/{redpajama_v2_samples_512_train.pbin => redpajama_v2_gpt2_tokenized_num_samples_512_train.pbin} (100%) rename data/tokenizer/{tokenizer.json => tokenizer_gpt2.json} (100%) rename examples/getting_started/{getting_started_example.md => README.md} (81%) create mode 100644 examples/library_usage/README.md create mode 100644 examples/library_usage/config_lorem_ipsum.yaml create mode 100644 examples/library_usage/main.py create mode 100644 examples/library_usage/run.sh create mode 100644 examples/pretraining_llama2/train.sh delete mode 100644 
src/modalities/checkpointing/checkpointing_factory.py create mode 100644 src/modalities/config/component_factory.py create mode 100644 src/modalities/config/lookup_enum.py delete mode 100644 src/modalities/config/lookup_types.py delete mode 100644 src/modalities/config/types.py create mode 100644 src/modalities/dataloader/dataset_factory.py create mode 100644 src/modalities/logging_broker/subscriber_impl/subscriber_factory.py create mode 100644 src/modalities/models/huggingface/__init__.py create mode 100644 src/modalities/models/huggingface/huggingface_models.py create mode 100644 src/modalities/models/model_factory.py rename src/{ => modalities/optimizers}/__init__.py (100%) create mode 100644 src/modalities/optimizers/optimizer_factory.py create mode 100644 src/modalities/registry/__init__.py create mode 100644 src/modalities/registry/components.py create mode 100644 src/modalities/registry/registry.py delete mode 100644 src/modalities/resolver_register.py create mode 100644 src/modalities/running_env/cuda_env.py create mode 100644 src/modalities/running_env/fsdp/fsdp_auto_wrapper.py delete mode 100644 src/modalities/running_env/fsdp/fsdp_running_env.py delete mode 100644 src/modalities/running_env/running_env.py create mode 100644 tests/config/__init__.py create mode 100644 tests/config/components.py create mode 100644 tests/config/configs.py create mode 100644 tests/config/custom_components.py create mode 100644 tests/config/test_component_factory.py create mode 100644 tests/config/test_configs/config_backward_reference.yaml create mode 100644 tests/config/test_configs/config_forward_reference.yaml create mode 100644 tests/config/test_configs/config_hierarchical_list_component.yaml create mode 100644 tests/config/test_configs/config_non_existing_reference.yaml create mode 100644 tests/config/test_configs/config_single_component.yaml delete mode 100644 tests/test_evaluation.py create mode 100644 tests/test_loss_functions.py diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index a97489e3..1d54a477 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -27,4 +27,11 @@ jobs: - name: Run tests run: | pytest + - name: Upload coverage data to coveralls.io + run: | + python -m pip install coveralls[toml] + coveralls --service=github + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + diff --git a/.gitignore b/.gitignore index 0bd78b69..ff12e477 100644 --- a/.gitignore +++ b/.gitignore @@ -55,6 +55,7 @@ htmlcov/ .cache nosetests.xml coverage.xml +coverage_html_report *.cover .hypothesis/ .pytest_cache/ diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 74106534..ed55664e 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -31,6 +31,7 @@ pre-commit install --install-hooks - Make sure your code passes all the tests and pre-commit hooks. Use `pytest` from within the root of your local repository. 
+- For vscode users, disable pytest coverage in `settings.json` to enable pytest debugging: `"python.testing.pytestArgs": ["--no-cov"]` ## Commit Guidelines diff --git a/README.md b/README.md index ced552c2..0573a4ba 100644 --- a/README.md +++ b/README.md @@ -1,9 +1,10 @@ # Modalities +[![Coverage Status](https://coveralls.io/repos/github/Modalities/modalities/badge.svg)](https://coveralls.io/github/Modalities/modalities) # Getting started For training and evaluation a model, feel free to checkout [this](https://github.com/Modalities/modalities/blob/main/examples/getting_started/getting_started_example.md) getting started tutorial, in which we train a small, 60M-parameter GPT model on a tiny subset of the Redpajama V2 dataset. -Also, see our WIki and API reference documentation: https://modalities.github.io/modalities/ +Also, see our Wiki and API reference documentation: https://modalities.github.io/modalities/ # Installation @@ -19,7 +20,7 @@ then, install the repository via pip install -e . ``` -If you want to contribute, have look at `CONTRIBUTING.md`. +If you want to contribute, have a look at `CONTRIBUTING.md`. @@ -56,12 +57,12 @@ Or, if you are a VsCode user, add this to your `launch.json`: # Pydantic and ClassResolver -The mechanismn introduced to instantiate classes via `type_hint` in the `config.yaml`, utilizes +The mechanism introduced to instantiate classes via `type_hint` in the `config.yaml`, utilizes 1) Omegaconf to load the config yaml file 2) Pydantic for the validation of the config 3) ClassResolver to instantiate the correct, concrete class of a class hierarchy. -Firstly, Omegaconf loads the config yaml file and resolves internal refrences such as `${subconfig.attribue}`. +Firstly, Omegaconf loads the config yaml file and resolves internal references such as `${subconfig.attribute}`. Then, Pydantic validates the whole config as is and checks that each of the sub-configs are `pydantic.BaseModel` classes. For configs, which allow different concrete classes to be instantiated by `ClassResolver`, the special member names `type_hint` and `config` are introduced. @@ -79,7 +80,7 @@ activation_kwargs={...} activation_resolver.make(type_hint, activation_kwargs), ``` -In our implmentation we go a step further, as both, +In our implementation we go a step further, as both, * a `type_hint` in a `BaseModel` config must be of type `modalities.config.lookup_types.LookupEnum` and * `config` is a union of allowed concrete configs of base type `BaseModel`. `config` hereby replaces `activation_kwargs` in the example above, and replaces it with pydantic-validated `BaseModel` configs. 
@@ -88,7 +89,8 @@ With this, a mapping between type hint strings needed for `class-resolver`, and ```python from enum import Enum -from pydantic import BaseModel, PositiveInt, PositiveFloat, conint, confloat +from typing import Annotated +from pydantic import BaseModel, PositiveInt, PositiveFloat, Field class LookupEnum(Enum): @classmethod @@ -101,8 +103,8 @@ class SchedulerTypes(LookupEnum): ConstantLR = torch.optim.lr_scheduler.ConstantLR class StepLRConfig(BaseModel): - step_size: conint(ge=1) - gamma: confloat(ge=0.0) + step_size: Annotated[int, Field(strict=True, ge=1)] + gamma: Annotated[float, Field(strict=True, ge=0.0)] class ConstantLRConfig(BaseModel): @@ -115,7 +117,7 @@ class SchedulerConfig(BaseModel): config: StepLRConfig | ConstantLRConfig ``` -To allow a user-friendly instantiation, all class resolvers are defined in the `ResolverRegistry` and `build_component_by_config` as convenience function is introduced. Dependecies can be passed-through with the `extra_kwargs` argument: +To allow a user-friendly instantiation, all class resolvers are defined in the `ResolverRegistry` and `build_component_by_config` as convenience function is introduced. Dependencies can be passed-through with the `extra_kwargs` argument: ```python resolvers = ResolverRegister(config=config) optimizer = ... # our example dependency @@ -187,20 +189,20 @@ Alternatively, directly use `src/modalities/__main__.py do_stuff --config_file_p The `MemMapDataset` requires an index file providing the necessary pointers into the raw data file. The `MemMapDataset` can create the index file lazily, however, it is advised to create it beforehand. This can be done by running ```sh -modalities create_memmap_index +modalities data create_raw_index ``` -The index will be created in the same directory as the raw data file. For further options you may look into the usage documentation via `modalities create_memmap_index --help`. +The index will be created in the same directory as the raw data file. For further options you may look into the usage documentation via `modalities data create_raw_index --help`. ## Packed Dataset Generator The `PackedMemMapDatasetContinuous` and `PackedMemMapDatasetMegatron` require a packed data file. To create the data file, you first have to generate a `MemMapDataset` index file as described [above](#memmapdataset-index-generator). Assuming the index and raw data are located in the same directory, you can simply execute the following command: ```sh -modalities create_packed_data +modalities data pack_encoded_data ``` -The packed data file will be created in the same directory as the raw data file. For further options you may look into the usage documentation via `modalities create_packed_data --help`. +The packed data file will be created in the same directory as the raw data file. For further options you may look into the usage documentation via `modalities data pack_encoded_data --help`. ### Packed Data Format diff --git a/benchmarks/dataloader/README.md b/benchmarks/dataloader/README.md new file mode 100644 index 00000000..850e8f06 --- /dev/null +++ b/benchmarks/dataloader/README.md @@ -0,0 +1,77 @@ +# Benchmarking of Dataset Implementations + +## Motivation +We want to include a storage efficient, fast and generic dataset implementation in this repository. +Previous work and ideas were based on MegatronLM and its dataset implementation. + +Unfortunately its usage is quite intransparent and causes regularly unexpected side effects. 
+Those problems are hard to trace, as we are not the original authors of the code. + +Therefore, we want to provide our own implementation, which comes with all of the above-mentioned benefits. +Most importantly, it should be at least as fast as MegatronLM's implementation. + + +## Benchmark Overview + +We want to evaluate multiple aspects of the dataset implementations: +* preparation speed - All datasets need to do some initial steps like tokenization and indexing. +* initialization speed - When firing up a respective `Dataset` object inside the code. +* iteration speed - When accessing elements (in a random order) in the respective datasets. + + +## Used Example Dataset + +The experiments were conducted on a small sample of openwebtext. The data is provided in `.jsonl`-format. +The relevant data included can be found under `"text"` and is obviously text-only. +Each dataset with X samples refers to the first X lines in the full openwebtext data, + as it can be obtained from huggingface. + + +## Experimental Setup + +We relied on the functions provided in `launch_benchmark.sh`. One can reproduce them by calling, e.g. + +```shell +. launch_benchmark.sh + +INPUT_DIR= + +echo "MegatronLM:" +measure_megatronLM_iteration +echo "Modalities:" +measure_modalities_iteration +``` + +> For launching the preparation of MegatronLM's dataset, refer to: +> https://github.com/OpenGPTX/opengptx_data/tree/docs/modalities-vs-megatronlm-dl and look at the `launch_benchmark.sh` +> script. + +#### Glossary + +* **preparation:** refers here to the task of turning raw data (e.g. jsonl encoded text) into a binary file, + which is loadable later for training. + For MegatronLM this means tokenizing and packing everything according to their defined format. + For Modalities it means indexing the raw data and packing it afterwards as token-ids. +* **initialization:** refers to the process of initializing a Python object, + which represents the respective dataset (mostly represented via the `torch.Dataset`-interface) +* **iteration:** refers to the process of iterating over the respective datasets - once sequentially and once shuffled. 
+ +## Results + + +| Evaluation Aspect | Implementation | Required Time | # Samples in Data | +|----------------------|----------------|:------------------:|-------------------| +| preparation speed | MegatronLM | `0 min 16.965 sec` | `20000(OWT)` | +| preparation speed | Modalities | `0 min 13.904 sec` | `20000(OWT)` | +| preparation speed | MegatronLM | `2 min 11.856 sec` | `200000(OWT)` | +| preparation speed | Modalities | `0 min 38.738 sec` | `200000(OWT)` | +| initialization speed | MegatronLM | `19.3 msec` | `20000(OWT)` | +| initialization speed | Modalities | `5.85 msec` | `20000(OWT)` | +| initialization speed | MegatronLM | `180 msec ` | `200000(OWT)` | +| initialization speed | Modalities | `58 msec` | `200000(OWT)` | +| iteration speed | MegatronLM | `52.4 msec` | `20000(OWT)` | +| iteration speed | Modalities | `66.8 msec` | `20000(OWT)` | +| iteration speed | MegatronLM | `426 msec ` | `200000(OWT)` | +| iteration speed | Modalities | `545 msec` | `200000(OWT)` | + + diff --git a/benchmarks/dataloader/launch_benchmark.sh b/benchmarks/dataloader/launch_benchmark.sh new file mode 100755 index 00000000..c4e9f69d --- /dev/null +++ b/benchmarks/dataloader/launch_benchmark.sh @@ -0,0 +1,87 @@ +#!/bin/bash + + + +INPUT_DIR="/tmp/i-do-not-exist.jsonl" + + +measure_modalities_preparation() { + time ( + set -e + test -f $INPUT_DIR + rm -f ${INPUT_DIR/.jsonl/.idx} + modalities data create_raw_index $INPUT_DIR &> /dev/null + echo "finished memmap index creation" + rm -f ${INPUT_DIR/.jsonl/.pbin} + modalities data pack_encoded_data $INPUT_DIR &> /dev/null + echo "finished memmap packing" + ) +} + + +measure_modalities_initialization() { + input_file=${INPUT_DIR/.jsonl/.pbin} + python -m timeit -n 50 -r 5 -s " +import sys, io +null_device = io.StringIO() +from modalities.dataloader.dataset import PackedMemMapDatasetMegatron +from pathlib import Path +p = Path(\"${input_file}\") + " -- " +sys.stdout = null_device # deactivate stdout to avoid getting spammed +PackedMemMapDatasetMegatron(raw_data_path=p, block_size=1024, sample_key=\"sample\") +sys.stdout = sys.__stdout__ # reactivate stdout for timeit +" +} + +measure_megatronLM_initialization() { + input_file="${INPUT_DIR/.jsonl/.megLM.bin_text_document}" + python -m timeit -n 50 -r 5 -s " +import sys, io +null_device = io.StringIO() +from modalities.dataloader.open_gptx_dataset.mmap_dataset import MMapIndexedDataset +p = \"${input_file}\" + " -- " +sys.stdout = null_device # deactivate stdout to avoid getting spammed +MMapIndexedDataset(p) +sys.stdout = sys.__stdout__ # reactivate stdout for timeit +" +} + +measure_modalities_iteration() { + input_file=${INPUT_DIR/.jsonl/.pbin} + python -m timeit -n 5 -r 3 -s " +import random, sys, io +null_device = io.StringIO() +from modalities.dataloader.dataset import PackedMemMapDatasetMegatron +from pathlib import Path +p = Path(\"${input_file}\") +sys.stdout = null_device # deactivate stdout to avoid getting spammed +dataset = PackedMemMapDatasetMegatron(raw_data_path=p, block_size=1024, sample_key=\"sample\") +random_indices = random.sample(range(len(dataset)), len(dataset)) +sys.stdout = sys.__stdout__ # reactivate stdout for timeit + " -- " +list(dataset) # sequential access +for i in random_indices: + dataset[i] +" +} + + +measure_megatronLM_iteration() { + input_file="${INPUT_DIR/.jsonl/.megLM.bin_text_document}" + python -m timeit -n 5 -r 3 -s " +import random, sys, io +null_device = io.StringIO() +from modalities.dataloader.open_gptx_dataset.mmap_dataset import MMapIndexedDataset +p = 
\"${input_file}\" +sys.stdout = null_device # deactivate stdout to avoid getting spammed +dataset = MMapIndexedDataset(p) +random_indices = random.sample(range(len(dataset)), len(dataset)) +sys.stdout = sys.__stdout__ # reactivate stdout for timeit + " -- " +list(dataset) # sequential access +for i in random_indices: + dataset[i] +" +} \ No newline at end of file diff --git a/config_files/config_example_hf_meditron_7B_instruction.yaml b/config_files/config_example_hf_meditron_7B_instruction.yaml new file mode 100644 index 00000000..590525dc --- /dev/null +++ b/config_files/config_example_hf_meditron_7B_instruction.yaml @@ -0,0 +1,199 @@ +settings: + experiment_id: ${modalities_env:experiment_id} + referencing_keys: + sample_key: input_ids + target_key: target_ids + prediction_key: logits + training: + callback_interval_in_samples: 2048 + global_num_training_samples: 2048 + global_num_seen_samples: 0 + do_apply_activation_checkpointing: true + gradient_acc_steps: 1 + local_train_micro_batch_size: 1 + sequence_length: 4096 + cuda_env: + local_rank: ${cuda_env:LOCAL_RANK} + global_rank: ${cuda_env:RANK} + world_size: ${cuda_env:WORLD_SIZE} + paths: + checkpointing_path: data/checkpoints + + +collate_fn: + component_key: collate_fn + variant_key: gpt_2_llm_collator + config: + sample_key: ${settings.referencing_keys.sample_key} + target_key: ${settings.referencing_keys.target_key} + +train_dataset: + component_key: dataset + variant_key: packed_mem_map_dataset_megatron + config: + raw_data_path: /raid/s3/opengptx/max_lue/modalities/data/sample_datasets/redpajama_v2/mem_map/redpajama_v2_gpt2_tokenized_num_samples_1050391.pbin + block_size: ${settings.training.sequence_length} + sample_key: ${settings.referencing_keys.sample_key} + +train_dataloader: + component_key: data_loader + variant_key: default + config: + num_workers: 2 + pin_memory: true + shuffle: false + dataloader_tag: "train" + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE + batch_sampler: + component_key: batch_sampler + variant_key: default + config: + batch_size: ${settings.training.local_train_micro_batch_size} + drop_last: false + sampler: + component_key: sampler + variant_key: distributed_sampler + config: + rank: ${settings.cuda_env.global_rank} + num_replicas: ${settings.cuda_env.world_size} + shuffle: true + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE + collate_fn: + instance_key: collate_fn + pass_type: BY_REFERENCE + +val_dataset: + component_key: dataset + variant_key: packed_mem_map_dataset_megatron + config: + raw_data_path: /raid/s3/opengptx/max_lue/modalities/data/sample_datasets/redpajama_v2/mem_map/redpajama_v2_gpt2_tokenized_num_samples_1024.pbin + block_size: ${settings.training.sequence_length} + sample_key: ${settings.referencing_keys.sample_key} + +val_dataloader: + component_key: data_loader + variant_key: default + config: + num_workers: 2 + pin_memory: true + shuffle: false + dataloader_tag: "val" + dataset: + instance_key: val_dataset + pass_type: BY_REFERENCE + batch_sampler: + component_key: batch_sampler + variant_key: default + config: + batch_size: ${settings.training.local_train_micro_batch_size} + drop_last: false + sampler: + component_key: sampler + variant_key: distributed_sampler + config: + rank: ${settings.cuda_env.global_rank} + num_replicas: ${settings.cuda_env.world_size} + shuffle: true + dataset: + instance_key: val_dataset + pass_type: BY_REFERENCE + collate_fn: + instance_key: collate_fn + pass_type: BY_REFERENCE + +eval_dataloaders: 
[val_dataloader] + +checkpointing: + component_key: checkpointing + variant_key: default + config: + checkpointing_strategy: + component_key: checkpointing_strategy + variant_key: save_k_most_recent_checkpoints_strategy + config: + k: -1 # -1 to save all checkpoints + checkpointing_execution: + component_key: checkpointing_execution + variant_key: fsdp_to_disc_checkpointing + config: + checkpoint_path: ${settings.paths.checkpointing_path} + global_rank: ${settings.cuda_env.global_rank} + experiment_id: ${settings.experiment_id} + mixed_precision_settings: BF_16 + sharding_strategy: FULL_SHARD + block_names: [LlamaDecoderLayer] + + +model: + component_key: model + variant_key: huggingface_pretrained_model + config: + model_type: AutoModelForCausalLM + model_name: epfl-llm/meditron-7b + sample_key: ${settings.referencing_keys.sample_key} + prediction_key: ${settings.referencing_keys.prediction_key} + huggingface_prediction_subscription_key: ${settings.referencing_keys.prediction_key} + kwargs: + cache_dir: /raid/s3/opengptx/max_lue/hf_cache/ + +wrapped_model: + component_key: model + variant_key: fsdp_wrapped + config: + model: + instance_key: model + pass_type: BY_REFERENCE + sync_module_states: true + mixed_precision_settings: BF_16 + sharding_strategy: FULL_SHARD + block_names: [LlamaDecoderLayer] + +loss_fn: + component_key: loss + variant_key: clm_cross_entropy_loss + config: + target_key: ${settings.referencing_keys.target_key} + prediction_key: ${settings.referencing_keys.prediction_key} + +# scheduler: +# type_hint: StepLR +# config: +# step_size: 1 +# gamma: 0.1 + +optimizer: + component_key: optimizer + variant_key: adam_w + config: + lr: 0.0001 + wrapped_model: + instance_key: wrapped_model + pass_type: BY_REFERENCE + + +batch_progress_subscriber: + component_key: progress_subscriber + variant_key: rich + config: + local_rank: ${settings.cuda_env.local_rank} + world_size: ${settings.cuda_env.world_size} + global_num_seen_samples: ${settings.training.global_num_seen_samples} + train_dataloader: + instance_key: train_dataloader + pass_type: BY_REFERENCE + eval_dataloaders: [] + + +evaluation_subscriber: + component_key: results_subscriber + variant_key: wandb + config: + local_rank: ${settings.cuda_env.local_rank} + project: modalities + mode: OFFLINE + experiment_id: ${settings.experiment_id} + directory: "." 
\ No newline at end of file diff --git a/config_files/config_example_mem_map_dataset.yaml b/config_files/config_example_mem_map_dataset.yaml index 005f2fca..62498856 100644 --- a/config_files/config_example_mem_map_dataset.yaml +++ b/config_files/config_example_mem_map_dataset.yaml @@ -1,114 +1,141 @@ -modalities_setup: - run_mode: FROM_SCRATCH - settings: +settings: + experiment_id: ${modalities_env:experiment_id} + referencing_keys: + sample_key: input_ids + target_key: target_ids + prediction_key: logits + training: + callback_interval_in_samples: 32768 + global_num_training_samples: 2048 global_num_seen_samples: 0 + do_apply_activation_checkpointing: false + gradient_acc_steps: 1 + local_train_micro_batch_size: 16 + sequence_length: 4096 + cuda_env: + local_rank: ${cuda_env:LOCAL_RANK} + global_rank: ${cuda_env:RANK} + world_size: ${cuda_env:WORLD_SIZE} + paths: + checkpointing_path: data/checkpoints -wandb: - project_name: modalities - mode: ONLINE -data: - sample_key: "input_ids" - target_key: "target_ids" - sequence_len: 1024 - train_dataloader: - type_hint: LLMDataLoader - config: - dataloader_tag: "train" - num_workers: 2 - pin_memory: true - shuffle: false - batch_sampler: - type_hint: BatchSampler - config: - batch_size: 8 # per rank - drop_last: false - sampler: - type_hint: DistributedSampler - config: - rank: ${training.global_rank} - num_replicas: ${training.world_size} - shuffle: true - dataset: - type_hint: PackedMemMapDatasetContinuous - config: - raw_data_path: ./data/sample_datasets/redpajama_v2/mem_map/redpajama_v2_samples_1024_train.pbin - block_size: ${data.sequence_len} - sample_key: ${data.sample_key} - collate_fn: - type_hint: GPT2LLMCollator - config: - sample_key: ${data.sample_key} - target_key: ${data.target_key} - eval_dataloaders: - - type_hint: LLMDataLoader +collate_fn: + component_key: collate_fn + variant_key: gpt_2_llm_collator + config: + sample_key: ${settings.referencing_keys.sample_key} + target_key: ${settings.referencing_keys.target_key} + +train_dataset: + component_key: dataset + variant_key: packed_mem_map_dataset_megatron + config: + raw_data_path: /raid/s3/opengptx/max_lue/LLMgym/data/redpyjama_v2_default_DE_num_docs_16777216.pbin + block_size: ${settings.training.sequence_length} + sample_key: ${settings.referencing_keys.sample_key} + +train_dataloader: + component_key: data_loader + variant_key: default + config: + num_workers: 2 + pin_memory: true + shuffle: false + dataloader_tag: "train" + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE + batch_sampler: + component_key: batch_sampler + variant_key: default config: - dataloader_tag: "val" - num_workers: 2 - pin_memory: true - shuffle: false - batch_sampler: - type_hint: BatchSampler - config: - batch_size: 8 # per rank - drop_last: false - sampler: - type_hint: DistributedSampler - config: - rank: ${training.global_rank} - num_replicas: ${training.world_size} - shuffle: false - dataset: - type_hint: PackedMemMapDatasetContinuous - config: - raw_data_path: ./data/sample_datasets/redpajama_v2/mem_map/redpajama_v2_samples_1024_test.pbin - block_size: ${data.sequence_len} - sample_key: ${data.sample_key} - collate_fn: - type_hint: GPT2LLMCollator - config: - sample_key: ${data.sample_key} - target_key: ${data.target_key} + batch_size: ${settings.training.local_train_micro_batch_size} + drop_last: false + sampler: + component_key: sampler + variant_key: distributed_sampler + config: + rank: ${settings.cuda_env.global_rank} + num_replicas: ${settings.cuda_env.world_size} + 
shuffle: true + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE + collate_fn: + instance_key: collate_fn + pass_type: BY_REFERENCE -training: - process_group_backend: "nccl" - global_num_training_samples: 2048 - callback_interval_in_samples: 256 - local_rank: ${oc.env:LOCAL_RANK} - global_rank: ${oc.env:RANK} - world_size: ${oc.env:WORLD_SIZE} - main_rank: 0 - local_train_micro_batch_size: ${data.train_dataloader.config.batch_sampler.config.batch_size} - global_num_seen_samples: ${modalities_setup.settings.global_num_seen_samples} - gradient_acc_step: 1 - do_apply_activation_checkpointing: false +val_dataset: + component_key: dataset + variant_key: packed_mem_map_dataset_megatron + config: + raw_data_path: /raid/s3/opengptx/max_lue/LLMgym/data/redpyjama_v2_default_DE_num_docs_1024.pbin + block_size: ${settings.training.sequence_length} + sample_key: ${settings.referencing_keys.sample_key} -checkpointing: - checkpointing_strategy: - type_hint: SaveKMostRecentCheckpointsStrategy - config: - k: -1 # -1 to save all checkpoints - checkpointing_execution: - type_hint: FSDPToDiscCheckpointing - config: - checkpoint_path: ./data/checkpoints - global_rank: ${oc.env:RANK} +val_dataloader: + component_key: data_loader + variant_key: default + config: + num_workers: 2 + pin_memory: true + shuffle: false + dataloader_tag: "val" + dataset: + instance_key: val_dataset + pass_type: BY_REFERENCE + batch_sampler: + component_key: batch_sampler + variant_key: default + config: + batch_size: ${settings.training.local_train_micro_batch_size} + drop_last: false + sampler: + component_key: sampler + variant_key: distributed_sampler + config: + rank: ${settings.cuda_env.global_rank} + num_replicas: ${settings.cuda_env.world_size} + shuffle: true + dataset: + instance_key: val_dataset + pass_type: BY_REFERENCE + collate_fn: + instance_key: collate_fn + pass_type: BY_REFERENCE + +eval_dataloaders: + - instance_key: val_dataloader + pass_type: BY_REFERENCE -running_env: - type_hint: FSDPRunningEnv +checkpointing: + component_key: checkpointing + variant_key: default config: - process_group_backend: ${training.process_group_backend} - local_rank: ${oc.env:LOCAL_RANK} - mixed_precision_settings: BF_16 - sharding_strategy: FULL_SHARD - auto_wrap_policy: TRANSFORMER_AUTO_WRAP_POLICY + checkpointing_strategy: + component_key: checkpointing_strategy + variant_key: save_k_most_recent_checkpoints_strategy + config: + k: -1 # -1 to save all checkpoints + checkpointing_execution: + component_key: checkpointing_execution + variant_key: fsdp_to_disc_checkpointing + config: + checkpoint_path: ${settings.paths.checkpointing_path} + global_rank: ${settings.cuda_env.global_rank} + experiment_id: ${settings.experiment_id} + mixed_precision_settings: BF_16 + sharding_strategy: FULL_SHARD + block_names: [GPT2Block] model: - type_hint: GPT2LLM + component_key: model + variant_key: gpt2 config: - sample_key: ${data.sample_key} - prediction_key: "logits" - block_size: ${data.sequence_len} + sample_key: ${settings.referencing_keys.sample_key} + prediction_key: ${settings.referencing_keys.prediction_key} + block_size: ${settings.training.sequence_length} vocab_size: 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency n_layer: 12 n_head: 12 @@ -125,20 +152,61 @@ model: mean: 0.0 std: 0.02 -scheduler: - type_hint: StepLR +wrapped_model: + component_key: model + variant_key: fsdp_wrapped config: - step_size: 1 - gamma: 0.1 + model: + instance_key: model + pass_type: BY_REFERENCE + 
sync_module_states: true + mixed_precision_settings: BF_16 + sharding_strategy: FULL_SHARD + block_names: [GPT2Block] -optimizer: - type_hint: AdamW +# scheduler: +# type_hint: StepLR +# config: +# step_size: 1 +# gamma: 0.1 + +loss_fn: + component_key: loss + variant_key: clm_cross_entropy_loss + config: + target_key: ${settings.referencing_keys.target_key} + prediction_key: ${settings.referencing_keys.prediction_key} + +optimizer: + component_key: optimizer + variant_key: adam_w config: lr: 0.0001 + wrapped_model: + instance_key: wrapped_model + pass_type: BY_REFERENCE +batch_progress_subscriber: + component_key: progress_subscriber + variant_key: rich + config: + local_rank: ${settings.cuda_env.local_rank} + world_size: ${settings.cuda_env.world_size} + global_num_seen_samples: ${settings.training.global_num_seen_samples} + train_dataloader: + instance_key: train_dataloader + pass_type: BY_REFERENCE + eval_dataloaders: + - instance_key: val_dataloader + pass_type: BY_REFERENCE -loss: - type_hint: CLMCrossEntropyLoss + +evaluation_subscriber: + component_key: results_subscriber + variant_key: wandb config: - target_key: ${data.target_key} - prediction_key: ${model.config.prediction_key} \ No newline at end of file + local_rank: ${settings.cuda_env.local_rank} + project: modalities + mode: ONLINE + experiment_id: ${settings.experiment_id} + directory: "." \ No newline at end of file diff --git a/config_files/config_lorem_ipsum.yaml b/config_files/config_lorem_ipsum.yaml index 9ac7c93f..c9f01291 100644 --- a/config_files/config_lorem_ipsum.yaml +++ b/config_files/config_lorem_ipsum.yaml @@ -1,136 +1,194 @@ -modalities_setup: - run_mode: FROM_SCRATCH - settings: +settings: + experiment_id: ${modalities_env:experiment_id} + referencing_keys: + sample_key: input_ids + target_key: target_ids + training: + callback_interval_in_samples: 6 + global_num_training_samples: 12 global_num_seen_samples: 0 + do_apply_activation_checkpointing: true + gradient_acc_steps: 1 + local_train_micro_batch_size: 3 + sequence_length: 256 + cuda_env: + local_rank: ${cuda_env:LOCAL_RANK} + global_rank: ${cuda_env:RANK} + world_size: ${cuda_env:WORLD_SIZE} + paths: + checkpointing_path: data/checkpoints +tokenizer: + component_key: tokenizer + variant_key: gpt2_tokenizer_fast + config: + tokenizer_file: data/tokenizer/tokenizer_gpt2.json + +collate_fn: + component_key: collate_fn + variant_key: gpt_2_llm_collator + config: + sample_key: ${settings.referencing_keys.sample_key} + target_key: ${settings.referencing_keys.target_key} + +train_dataset: + component_key: dataset + variant_key: mem_map_dataset + config: + raw_data_path: data/lorem_ipsum.jsonl + index_path: data/lorem_ipsum.idx + block_size: ${settings.training.sequence_length} + jq_pattern: ".text" + sample_key: ${settings.referencing_keys.sample_key} + tokenizer: + instance_key: tokenizer + pass_type: BY_REFERENCE -wandb: - project_name: modalities - mode: OFFLINE +train_dataloader: + component_key: data_loader + variant_key: default + config: + num_workers: 2 + pin_memory: true + shuffle: false + dataloader_tag: "train" + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE + batch_sampler: + component_key: batch_sampler + variant_key: default + config: + batch_size: ${settings.training.local_train_micro_batch_size} + drop_last: false + sampler: + component_key: sampler + variant_key: distributed_sampler + config: + rank: ${settings.cuda_env.global_rank} + num_replicas: ${settings.cuda_env.world_size} + shuffle: true + dataset: + instance_key: 
train_dataset + pass_type: BY_REFERENCE + collate_fn: + instance_key: collate_fn + pass_type: BY_REFERENCE -data: - sample_key: "input_ids" - target_key: "target_ids" - sequence_len: 128 - train_dataloader: - type_hint: LLMDataLoader - config: - num_workers: 2 - pin_memory: true - shuffle: false - dataloader_tag: "train" - dataset: - type_hint: MemMapDataset - config: - raw_data_path: data/lorem_ipsum.jsonl - index_path: data/lorem_ipsum.idx - block_size: ${data.sequence_len} - jq_pattern: ".text" - sample_key: ${data.sample_key} - tokenizer: - type_hint: GPT2TokenizerFast - config: - tokenizer_file: data/tokenizer/tokenizer.json - batch_sampler: - type_hint: BatchSampler - config: - batch_size: 3 - drop_last: false - sampler: - type_hint: DistributedSampler - config: - rank: ${training.global_rank} - num_replicas: ${training.world_size} - shuffle: true - collate_fn: - type_hint: GPT2LLMCollator - config: - sample_key: ${data.sample_key} - target_key: ${data.target_key} - eval_dataloaders: - - type_hint: LLMDataLoader +val_dataloader: + component_key: data_loader + variant_key: default + config: + num_workers: 2 + pin_memory: true + shuffle: false + dataloader_tag: "val" + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE + batch_sampler: + component_key: batch_sampler + variant_key: default config: - num_workers: 2 - pin_memory: true - shuffle: false - dataloader_tag: "val" - dataset: ${data.train_dataloader.config.dataset} - batch_sampler: - type_hint: BatchSampler + batch_size: 3 + drop_last: false + sampler: + component_key: sampler + variant_key: distributed_sampler config: - batch_size: 3 - drop_last: false - sampler: - type_hint: DistributedSampler - config: - rank: ${training.global_rank} - num_replicas: ${training.world_size} - shuffle: true - collate_fn: ${data.train_dataloader.config.collate_fn} - - type_hint: LLMDataLoader + rank: ${settings.cuda_env.global_rank} + num_replicas: ${settings.cuda_env.world_size} + shuffle: false + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE + collate_fn: + instance_key: collate_fn + pass_type: BY_REFERENCE + +test_dataloader: + component_key: data_loader + variant_key: default + config: + num_workers: 2 + pin_memory: true + shuffle: false + dataloader_tag: "test" + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE + batch_sampler: + component_key: batch_sampler + variant_key: default config: - num_workers: 2 - pin_memory: true - shuffle: false - dataloader_tag: "test" - dataset: ${data.train_dataloader.config.dataset} - batch_sampler: - type_hint: BatchSampler + batch_size: 3 + drop_last: false + sampler: + component_key: sampler + variant_key: distributed_sampler config: - batch_size: 3 - drop_last: false - sampler: - type_hint: DistributedSampler - config: - rank: ${training.global_rank} - num_replicas: ${training.world_size} - shuffle: true - collate_fn: ${data.train_dataloader.config.collate_fn} + rank: ${settings.cuda_env.global_rank} + num_replicas: ${settings.cuda_env.world_size} + shuffle: false + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE + collate_fn: + instance_key: collate_fn + pass_type: BY_REFERENCE -training: - process_group_backend: "nccl" - global_num_training_samples: 12 - callback_interval_in_samples: 6 - local_rank: ${oc.env:LOCAL_RANK} - global_rank: ${oc.env:RANK} - world_size: ${oc.env:WORLD_SIZE} - main_rank: 0 - local_train_micro_batch_size: ${data.train_dataloader.config.batch_sampler.config.batch_size} - global_num_seen_samples: 
${modalities_setup.settings.global_num_seen_samples} - gradient_acc_step: 1 - do_apply_activation_checkpointing: True +eval_dataloaders: + - instance_key: val_dataloader + pass_type: BY_REFERENCE + - instance_key: test_dataloader + pass_type: BY_REFERENCE checkpointing: - checkpointing_strategy: - type_hint: SaveKMostRecentCheckpointsStrategy - config: - k: -1 # -1 to save all checkpoints - checkpointing_execution: - type_hint: FSDPToDiscCheckpointing - config: - checkpoint_path: data/checkpoints - global_rank: ${oc.env:RANK} + component_key: checkpointing + variant_key: default + config: + checkpointing_strategy: + component_key: checkpointing_strategy + variant_key: save_k_most_recent_checkpoints_strategy + config: + k: -1 # -1 to save all checkpoints + checkpointing_execution: + component_key: checkpointing_execution + variant_key: fsdp_to_disc_checkpointing + config: + checkpoint_path: ${settings.paths.checkpointing_path} + global_rank: ${settings.cuda_env.global_rank} + experiment_id: ${settings.experiment_id} + mixed_precision_settings: BF_16 + sharding_strategy: FULL_SHARD + block_names: [GPT2Block] -loss: - type_hint: CLMCrossEntropyLoss +# resolving class types via different enums sucks... +loss_fn: + component_key: loss + variant_key: clm_cross_entropy_loss config: - target_key: ${data.target_key} - prediction_key: ${model.config.prediction_key} + target_key: target_ids + prediction_key: logits -running_env: - type_hint: FSDPRunningEnv +wrapped_model: + component_key: model + variant_key: fsdp_wrapped config: - process_group_backend: ${training.process_group_backend} - local_rank: ${oc.env:LOCAL_RANK} - mixed_precision_settings: FP_16 + model: + instance_key: model + pass_type: BY_REFERENCE + sync_module_states: true + mixed_precision_settings: BF_16 sharding_strategy: FULL_SHARD - auto_wrap_policy: TRANSFORMER_AUTO_WRAP_POLICY + block_names: [GPT2Block] model: - type_hint: GPT2LLM + component_key: model + variant_key: gpt2 config: - sample_key: ${data.sample_key} - prediction_key: "logits" - block_size: ${data.sequence_len} + sample_key: "input_ids" # TODO reference this + prediction_key: "logits" # TODO reference this + block_size: 256 # TODO reference this (same as sequence length) vocab_size: 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency n_layer: 2 n_head: 4 @@ -139,7 +197,7 @@ model: dropout: 0.0 bias: true # True: bias in Linears and LayerNorms, like GPT-2. 
False: a bit better and faster attention: - attention_type: pytorch_flash_attention + attention_type: default_attention # pytorch_flash_attention scaling_factor: 3 activation: gelu epsilon: 1e-5 @@ -147,13 +205,45 @@ model: mean: 0.0 std: 0.02 -scheduler: - type_hint: StepLR - config: - step_size: 1 - gamma: 0.1 +# scheduler: +# type_hint: StepLR +# config: +# step_size: 1 +# gamma: 0.1 -optimizer: - type_hint: AdamW +optimizer: + component_key: optimizer + variant_key: adam_w config: lr: 0.0001 + wrapped_model: + instance_key: wrapped_model + pass_type: BY_REFERENCE + +# message subscriber + +batch_progress_subscriber: + component_key: progress_subscriber + variant_key: rich + config: + local_rank: ${settings.cuda_env.local_rank} + world_size: ${settings.cuda_env.world_size} + global_num_seen_samples: ${settings.training.global_num_seen_samples} + train_dataloader: + instance_key: train_dataloader + pass_type: BY_REFERENCE + eval_dataloaders: + instance_key: eval_dataloaders + pass_type: BY_REFERENCE + + +evaluation_subscriber: + component_key: results_subscriber + variant_key: wandb + config: + local_rank: ${settings.cuda_env.local_rank} + project: modalities + mode: OFFLINE + experiment_id: ${settings.experiment_id} + directory: "." + \ No newline at end of file diff --git a/data/sample_datasets/redpajama_v2/mem_map/redpajama_v2_samples_512_test.idx b/data/sample_datasets/redpajama_v2/mem_map/redpajama_v2_gpt2_tokenized_num_samples_512_test.idx similarity index 100% rename from data/sample_datasets/redpajama_v2/mem_map/redpajama_v2_samples_512_test.idx rename to data/sample_datasets/redpajama_v2/mem_map/redpajama_v2_gpt2_tokenized_num_samples_512_test.idx diff --git a/data/sample_datasets/redpajama_v2/mem_map/redpajama_v2_samples_512_test.pbin b/data/sample_datasets/redpajama_v2/mem_map/redpajama_v2_gpt2_tokenized_num_samples_512_test.pbin similarity index 100% rename from data/sample_datasets/redpajama_v2/mem_map/redpajama_v2_samples_512_test.pbin rename to data/sample_datasets/redpajama_v2/mem_map/redpajama_v2_gpt2_tokenized_num_samples_512_test.pbin diff --git a/data/sample_datasets/redpajama_v2/mem_map/redpajama_v2_samples_512_train.idx b/data/sample_datasets/redpajama_v2/mem_map/redpajama_v2_gpt2_tokenized_num_samples_512_train.idx similarity index 100% rename from data/sample_datasets/redpajama_v2/mem_map/redpajama_v2_samples_512_train.idx rename to data/sample_datasets/redpajama_v2/mem_map/redpajama_v2_gpt2_tokenized_num_samples_512_train.idx diff --git a/data/sample_datasets/redpajama_v2/mem_map/redpajama_v2_samples_512_train.pbin b/data/sample_datasets/redpajama_v2/mem_map/redpajama_v2_gpt2_tokenized_num_samples_512_train.pbin similarity index 100% rename from data/sample_datasets/redpajama_v2/mem_map/redpajama_v2_samples_512_train.pbin rename to data/sample_datasets/redpajama_v2/mem_map/redpajama_v2_gpt2_tokenized_num_samples_512_train.pbin diff --git a/data/tokenizer/tokenizer.json b/data/tokenizer/tokenizer_gpt2.json similarity index 100% rename from data/tokenizer/tokenizer.json rename to data/tokenizer/tokenizer_gpt2.json diff --git a/docs/source/configuration.rst b/docs/source/configuration.rst index e51335fa..85545671 100644 --- a/docs/source/configuration.rst +++ b/docs/source/configuration.rst @@ -47,7 +47,8 @@ With this, a mapping between type hint strings needed for `class-resolver`, and .. 
code-block:: python from enum import Enum - from pydantic import BaseModel, PositiveInt, PositiveFloat, conint, confloat + from typing import Annotated + from pydantic import BaseModel, PositiveInt, PositiveFloat, Field class LookupEnum(Enum): @classmethod @@ -60,8 +61,8 @@ With this, a mapping between type hint strings needed for `class-resolver`, and ConstantLR = torch.optim.lr_scheduler.ConstantLR class StepLRConfig(BaseModel): - step_size: conint(ge=1) - gamma: confloat(ge=0.0) + step_size: Annotated[int, Field(strict=True, ge=1)] + gamma: Annotated[float, Field(strict=True, ge=0.0)] class ConstantLRConfig(BaseModel): diff --git a/docs/source/memmap.rst b/docs/source/memmap.rst index 22793c08..84326fc4 100644 --- a/docs/source/memmap.rst +++ b/docs/source/memmap.rst @@ -14,9 +14,9 @@ The :python:`MemMapDataset` requires an index file providing the necessary point .. code-block:: bash - modalities create_memmap_index + modalities data create_raw_index -The index will be created in the same directory as the raw data file. For further options you may look into the usage documentation via :bash:`modalities create_memmap_index --help`. +The index will be created in the same directory as the raw data file. For further options you may look into the usage documentation via :bash:`modalities data create_raw_index --help`. Packed Dataset Generator -------------------------------------------------------------------------------- @@ -25,9 +25,9 @@ The :python:`PackedMemMapDatasetContinuous` and :python:`PackedMemMapDatasetMega .. code-block:: bash - modalities create_packed_data + modalities data pack_encoded_data -The packed data file will be created in the same directory as the raw data file. For further options you may look into the usage documentation via :bash:`modalities create_packed_data --help`. +The packed data file will be created in the same directory as the raw data file. For further options you may look into the usage documentation via :bash:`modalities data pack_encoded_data --help`. Packed Data Format ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/source/quickstart.rst b/docs/source/quickstart.rst index d39a78a8..a8a81900 100644 --- a/docs/source/quickstart.rst +++ b/docs/source/quickstart.rst @@ -12,20 +12,20 @@ To start a training you need to create memmap dataset out of a jsonl file first, .. code-block:: bash # Create memmap dataset from jsonl file. - modalities create_memmap_index + modalities data create_raw_index # Create packed dataset. - modalities create_packed_data + modalities data pack_encoded_data For example, using the lorem ipsum example: .. code-block:: bash # Create memmap dataset from jsonl file. - modalities create_memmap_index data/lorem_ipsum.jsonl + modalities data create_raw_index data/lorem_ipsum.jsonl # Create packed dataset. - modalities create_packed_data data/lorem_ipsum.jsonl + modalities data pack_encoded_data data/lorem_ipsum.jsonl Training ---------------------------------------------------- diff --git a/examples/getting_started/getting_started_example.md b/examples/getting_started/README.md similarity index 81% rename from examples/getting_started/getting_started_example.md rename to examples/getting_started/README.md index a5ffdd74..53f49371 100644 --- a/examples/getting_started/getting_started_example.md +++ b/examples/getting_started/README.md @@ -11,7 +11,7 @@ As a reference, this example has the following folder structure. 
Folders in <> w ├── example_config.yaml ├── data │ ├── mem_map - │ │ └ + │ ├── │ └── raw │ ├── redpajama_v2_samples_512_test.jsonl │ └── redpajama_v2_samples_512_train.jsonl @@ -23,8 +23,7 @@ As a reference, this example has the following folder structure. Folders in <> w ``` ## 1. Preprocessing -A single line of the Redpajama V2 JSONL file has the structure denoted below. Since we are not interested in the meta data and quality signals for this minimal example, we consider the `raw_content` from each line without any filtering. -for model training. +A single line of the Redpajama V2 JSONL file has the structure denoted below. Since we are not interested in the meta data and quality signals for this minimal example, we consider the `raw_content` from each line without any filtering for model training. ```json { "raw_content":"Archivio Tag: 25 aprile\nSupermercati aperti 25 aprile 2019: centri commerciali e negozi a Roma, Milano, Napoli e Torino\nNell\u2019articolo odierno troverete tutte le informazioni utili su quali saranno i supermercati e le attivit\u00e0 commerciali che resteranno aperti in occasione...\nAuguri di Buon 25 Aprile 2017: frasi e pensieri originali sulla Festa della Liberazione", @@ -42,29 +41,29 @@ Firstly, we create the dataset index via cd modalities/examples/getting_started/ # train split -modalities create_memmap_index --index_path data/mem_map/redpajama_v2_samples_512_train.idx \ +modalities data create_raw_index --index_path data/mem_map/redpajama_v2_samples_512_train.idx \ data/raw/redpajama_v2_samples_512_train.jsonl # test split -modalities create_memmap_index --index_path data/mem_map/redpajama_v2_samples_512_test.idx \ +modalities data create_raw_index --index_path data/mem_map/redpajama_v2_samples_512_test.idx \ data/raw/redpajama_v2_samples_512_test.jsonl ``` -In this step, we read the JSON file as a binary file, iterate over all characters und build up the sample index (char-wisestart and end position for each JSON sample) -as determined by the `\n` character positions. The sample index is stored in the specified `index_path`. Internally, the `create_memmap_index` command -instantiates and calls the the [IndexGenerator](https://github.com/Modalities/modalities/blob/main/src/modalities/dataloader/create_index.py#L14). +In this step, we read the JSON file as a binary file, iterate over all characters and build up the sample index (char-wise start and end position for each JSON sample) +as determined by the `\n` character positions. The sample index is stored in the specified `index_path`. Internally, the `create_raw_index` command +instantiates and calls the [IndexGenerator](https://github.com/Modalities/modalities/blob/main/src/modalities/dataloader/create_index.py#L14). After having determined the index, we create the packed dataset as described below by leveraging the tokenizer, jsonl file and the created index. 
```sh # train split -modalities create_packed_data --jq_pattern .raw_content \ +modalities data pack_encoded_data --jq_pattern .raw_content \ --index_path data/mem_map/redpajama_v2_samples_512_train.idx \ --dst_path data/mem_map/redpajama_v2_samples_512_train.pbin \ --tokenizer_file tokenizer/tokenizer.json \ data/raw/redpajama_v2_samples_512_train.jsonl # test split -modalities create_packed_data --jq_pattern .raw_content \ +modalities data pack_encoded_data --jq_pattern .raw_content \ --index_path data/mem_map/redpajama_v2_samples_512_test.idx \ --dst_path data/mem_map/redpajama_v2_samples_512_test.pbin \ --tokenizer_file tokenizer/tokenizer.json \ @@ -84,15 +83,21 @@ Technically, packed datasets are defined a self-contained format that stores the **Packed MemMap File Format** ``` -|--8-BYTES-HEADER--|-------------------DATA-SEGMENT-------------------|----INDEX-SEGMENT----| +|--HEADER--|-------------------DATA-SEGMENT-------------------|----INDEX-SEGMENT----| -8 bytes header: +header: =============== -specifies the size of the data segment in bytes. Since the header size is fixed to 8 bytes, -the start and end position of each segment (i.e, header, data, index) is specified. Therefore, the theoretical maximum size of the data segment -is 2^64 bytes = 18,446 peta bytes or 4600e+15 tokens or 4.6 quintillion tokens, given that a token has 4 bytes. - +Contains two elements: +* Specifies the size of the data segment in bytes. Since the header size is fixed to 8 bytes, + the start and end position of each segment (i.e, header, data, index) is specified. + Therefore, the theoretical maximum size of the data segment + is 2^64 bytes = 18,446 peta bytes or 4600e+15 tokens or 4.6 quintillion tokens, given that a token has 4 bytes. +* The size of a each represented single token in the data segment in bytes. + This values is inferred from the source data of this `.pbin` + and depends solely on the tokenizer's vocabulary used for encoding. + A 4-byte integer is used for this. +Therefore the header is always 8+4=12 bytes long. Data segment: ============= @@ -115,7 +120,7 @@ first and then divides it into chunks of size context-length. In modalities, we describe the entire training and evaluation setup (i.e., components such das model, trainer, evaluator, dataloder etc.) within a single config file. Not only does this increase reproducibility but also allows for having the entire training runs under version control. -The example config file for this experiment can be found in `examples/mem_map_redpajama_gpt/config_example_mem_map_dataset.yaml`. +The example config file for this experiment can be found in `examples/getting_started/example_config.yaml`. ## 2. Training @@ -151,8 +156,8 @@ The command can be broken down into the following parts: 7. **`run`**: - Command argument for the `modalities` executable to initiate the training. -8. **`--config_file_path config_example_mem_map_dataset.yaml`**: - - Specifies the path to the configuration file. The file `config_example_mem_map_dataset.yaml` contains mentinoed configuratino of the components, including dataset and model configurations, training parameters, etc. +8. **`--config_file_path example_config.yaml`**: + - Specifies the path to the configuration file. The file `example_config.yaml` contains the configuration of the components, including dataset and model configurations, training parameters, etc. 
Already during the training, the checkpoints can be found locally in `checkpoints/` and the loss and metric developments can be inspected online in [Weights&Biases](https://wandb.ai/). @@ -171,5 +176,4 @@ which opens an interactive chatting CMD interface. ``` enter prompt> Once upon a time, there was ... - ``` \ No newline at end of file diff --git a/examples/library_usage/README.md b/examples/library_usage/README.md new file mode 100644 index 00000000..89898c53 --- /dev/null +++ b/examples/library_usage/README.md @@ -0,0 +1,79 @@ +# Running Modalities like a package + +Modalities can be used in a library fashion by installing the package via `pip`, as described in the [README](https://github.com/Modalities/modalities?tab=readme-ov-file#installation). The framework allows for the addition of custom components to the registry at runtime without necessitating any code changes to modalities. This functionality is achieved in Modalities with the introduction of a component registry, containing all the internal components (e.g., Dataloader, Loss function etc.). To support the addition of custom components (e.g., new model architectures) at runtime, Modalities exposes a function endpoint adding custom components to the internal registry. + +A typical use case for running Modalities in package-like fashion would be to have a custom model implemented in a repository parallel to modalities. To train the model, we would register the model class and its config class within Modalities' registry and additionally provide the typical training config (see [here](https://github.com/Modalities/modalities/blob/main/examples/getting_started/example_config.yaml) for an example) that also references the new model. Since modalities is aware of the model and config class, the model can be built from the config YAML file and used for training. + +## Concrete Example + +Given the explanation above, we now provide a minimal dummy example of the process of implementing, registering and instantiating a custom component via the example of a custom collate function. +The full example code can be found [here](https://github.com/Modalities/modalities/tree/hierarchical_instantiation/examples/library_usage). + +The code for the custom collate function, its config and registering is implemented in +[main.py](https://github.com/Modalities/modalities/blob/hierarchical_instantiation/examples/library_usage/main.py). Firstly, the script implements the custom collate function by first defining the config that parameterizes the collate function. Here, we took the two attributes from the original [GPT2LLMCollateFnConfig]() and added the custom field `custom_attribute`. + +```python + class CustomGPT2LLMCollateFnConfig(BaseModel): + sample_key: str + target_key: str + custom_attribute: str +``` + +The collate function implements the `CollateFnIF` interface. Its constructor expects the attributes from the previously defined `CustomGPT2LLMCollateFnConfig`. Since this is only a minimal example to demonstrate the registering of custom components, we just print the custom attribute without adding any senseful functionality. 
+
+```python
+class CustomGPT2LLMCollateFn(CollateFnIF):
+    def __init__(self, sample_key: str, target_key: str, custom_attribute: str):
+        self.sample_key = sample_key
+        self.target_key = target_key
+        self.custom_attribute = custom_attribute
+
+    def __call__(self, batch: List[Dict[str, torch.Tensor]]) -> DatasetBatch:
+        sample_tensor = torch.stack([torch.tensor(d[self.sample_key]) for d in batch])
+        samples = {self.sample_key: sample_tensor[:, :-1]}
+        targets = {self.target_key: sample_tensor[:, 1:]}
+
+        print(f"Custom attribute: {self.custom_attribute}")
+
+        return DatasetBatch(targets=targets, samples=samples)
+```
+
+Given `CustomGPT2LLMCollateFnConfig` and `CustomGPT2LLMCollateFn`, we register the new component via `add_custom_component(...)` by providing the respective component key and variant key together with the two previously defined classes. Note that even though the `component_key` and `variant_key` are in principle arbitrary, it is good practice to follow the patterns used for the internal components, as defined in [components.py](https://github.com/Modalities/modalities/blob/hierarchical_instantiation/src/modalities/registry/components.py#L64).
+
+```python
+def main():
+    # load and parse the config file
+    config_file_path = Path("config_lorem_ipsum.yaml")
+    config_dict = load_app_config_dict(config_file_path)
+
+    # instantiate the Main entrypoint of modalities by passing in the config
+    modalities_main = Main(config_dict=config_dict, config_path=config_file_path)
+
+    # add the custom component to modalities
+    modalities_main.add_custom_component(
+        component_key="collate_fn",
+        variant_key="custom_gpt_2_llm_collator",
+        custom_component=CustomGPT2LLMCollateFn,
+        custom_config=CustomGPT2LLMCollateFnConfig,
+    )
+    # run the experiment
+    modalities_main.run()
+```
+
+Lastly, we add the `collate_fn` entry with the new collator to the [example YAML config](https://github.com/Modalities/modalities/blob/hierarchical_instantiation/examples/library_usage/config_lorem_ipsum.yaml).
+```yaml
+collate_fn:
+  component_key: collate_fn
+  variant_key: custom_gpt_2_llm_collator
+  config:
+    sample_key: ${settings.referencing_keys.sample_key}
+    target_key: ${settings.referencing_keys.target_key}
+    custom_attribute: "custom_value"
+```
+
+Given the changes above, we are now ready to run the training by executing the following bash command in the example directory.
+```sh +CUDA_VISIBLE_DEVICES=0,1 torchrun --rdzv-endpoint localhost:29504 --nnodes 1 --nproc_per_node 2 main.py +``` + + diff --git a/examples/library_usage/config_lorem_ipsum.yaml b/examples/library_usage/config_lorem_ipsum.yaml new file mode 100644 index 00000000..f41a2507 --- /dev/null +++ b/examples/library_usage/config_lorem_ipsum.yaml @@ -0,0 +1,245 @@ +settings: + experiment_id: ${modalities_env:experiment_id} + referencing_keys: + sample_key: input_ids + target_key: target_ids + training: + callback_interval_in_samples: 6 + global_num_training_samples: 12 + global_num_seen_samples: 0 + do_apply_activation_checkpointing: true + gradient_acc_steps: 1 + local_train_micro_batch_size: 3 + sequence_length: 256 + cuda_env: + local_rank: ${cuda_env:LOCAL_RANK} + global_rank: ${cuda_env:RANK} + world_size: ${cuda_env:WORLD_SIZE} + paths: + checkpointing_path: data/checkpoints + +tokenizer: + component_key: tokenizer + variant_key: gpt2_tokenizer_fast + config: + tokenizer_file: tokenizer_gpt2.json + +collate_fn: + component_key: collate_fn + variant_key: custom_gpt_2_llm_collator + config: + sample_key: ${settings.referencing_keys.sample_key} + target_key: ${settings.referencing_keys.target_key} + custom_attribute: "custom_value" + +train_dataset: + component_key: dataset + variant_key: mem_map_dataset + config: + raw_data_path: ../../data/lorem_ipsum.jsonl + index_path: ../../data/lorem_ipsum.idx + block_size: ${settings.training.sequence_length} + jq_pattern: ".text" + sample_key: ${settings.referencing_keys.sample_key} + tokenizer: + instance_key: tokenizer + pass_type: BY_REFERENCE + +train_dataloader: + component_key: data_loader + variant_key: default + config: + num_workers: 2 + pin_memory: true + shuffle: false + dataloader_tag: "train" + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE + batch_sampler: + component_key: batch_sampler + variant_key: default + config: + batch_size: ${settings.training.local_train_micro_batch_size} + drop_last: false + sampler: + component_key: sampler + variant_key: distributed_sampler + config: + rank: ${settings.cuda_env.global_rank} + num_replicas: ${settings.cuda_env.world_size} + shuffle: true + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE + collate_fn: + instance_key: collate_fn + pass_type: BY_REFERENCE + +val_dataloader: + component_key: data_loader + variant_key: default + config: + num_workers: 2 + pin_memory: true + shuffle: false + dataloader_tag: "val" + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE + batch_sampler: + component_key: batch_sampler + variant_key: default + config: + batch_size: 3 + drop_last: false + sampler: + component_key: sampler + variant_key: distributed_sampler + config: + rank: ${settings.cuda_env.global_rank} + num_replicas: ${settings.cuda_env.world_size} + shuffle: false + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE + collate_fn: + instance_key: collate_fn + pass_type: BY_REFERENCE + +test_dataloader: + component_key: data_loader + variant_key: default + config: + num_workers: 2 + pin_memory: true + shuffle: false + dataloader_tag: "test" + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE + batch_sampler: + component_key: batch_sampler + variant_key: default + config: + batch_size: 3 + drop_last: false + sampler: + component_key: sampler + variant_key: distributed_sampler + config: + rank: ${settings.cuda_env.global_rank} + num_replicas: ${settings.cuda_env.world_size} + shuffle: false + dataset: + instance_key: 
train_dataset + pass_type: BY_REFERENCE + collate_fn: + instance_key: collate_fn + pass_type: BY_REFERENCE + +eval_dataloaders: + - instance_key: val_dataloader + pass_type: BY_REFERENCE + - instance_key: test_dataloader + pass_type: BY_REFERENCE + +checkpointing: + component_key: checkpointing + variant_key: default + config: + checkpointing_strategy: + component_key: checkpointing_strategy + variant_key: save_k_most_recent_checkpoints_strategy + config: + k: -1 # -1 to save all checkpoints + checkpointing_execution: + component_key: checkpointing_execution + variant_key: fsdp_to_disc_checkpointing + config: + checkpoint_path: ${settings.paths.checkpointing_path} + global_rank: ${settings.cuda_env.global_rank} + experiment_id: ${settings.experiment_id} + mixed_precision_settings: BF_16 + sharding_strategy: FULL_SHARD + block_names: [GPT2Block] + +# resolving class types via different enums sucks... +loss_fn: + component_key: loss + variant_key: clm_cross_entropy_loss + config: + target_key: target_ids + prediction_key: logits + +wrapped_model: + component_key: model + variant_key: fsdp_wrapped + config: + model: + instance_key: model + pass_type: BY_REFERENCE + sync_module_states: true + mixed_precision_settings: BF_16 + sharding_strategy: FULL_SHARD + block_names: [GPT2Block] + +model: + component_key: model + variant_key: gpt2 + config: + sample_key: "input_ids" # TODO reference this + prediction_key: "logits" # TODO reference this + block_size: 256 # TODO reference this (same as sequence length) + vocab_size: 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency + n_layer: 2 + n_head: 4 + ffn_hidden: 128 + n_embd: 128 + dropout: 0.0 + bias: true # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster + attention: + attention_type: default_attention # pytorch_flash_attention + scaling_factor: 3 + activation: gelu + epsilon: 1e-5 + weight_init: + mean: 0.0 + std: 0.02 + +optimizer: + component_key: optimizer + variant_key: adam_w + config: + lr: 0.0001 + wrapped_model: + instance_key: wrapped_model + pass_type: BY_REFERENCE + +# message subscriber + +batch_progress_subscriber: + component_key: progress_subscriber + variant_key: rich + config: + local_rank: ${settings.cuda_env.local_rank} + world_size: ${settings.cuda_env.world_size} + global_num_seen_samples: ${settings.training.global_num_seen_samples} + train_dataloader: + instance_key: train_dataloader + pass_type: BY_REFERENCE + eval_dataloaders: + instance_key: eval_dataloaders + pass_type: BY_REFERENCE + + +evaluation_subscriber: + component_key: results_subscriber + variant_key: wandb + config: + local_rank: ${settings.cuda_env.local_rank} + project: modalities + mode: OFFLINE + experiment_id: ${settings.experiment_id} + directory: "." 
+ \ No newline at end of file diff --git a/examples/library_usage/main.py b/examples/library_usage/main.py new file mode 100644 index 00000000..cb03eb63 --- /dev/null +++ b/examples/library_usage/main.py @@ -0,0 +1,55 @@ +from pathlib import Path +from typing import Dict, List + +import torch +from pydantic import BaseModel + +from modalities.__main__ import Main +from modalities.batch import DatasetBatch +from modalities.config.config import load_app_config_dict +from modalities.models.gpt2.collator import CollateFnIF + + +class CustomGPT2LLMCollateFnConfig(BaseModel): + sample_key: str + target_key: str + custom_attribute: str + + +class CustomGPT2LLMCollateFn(CollateFnIF): + def __init__(self, sample_key: str, target_key: str, custom_attribute: str): + self.sample_key = sample_key + self.target_key = target_key + self.custom_attribute = custom_attribute + + def __call__(self, batch: List[Dict[str, torch.Tensor]]) -> DatasetBatch: + sample_tensor = torch.stack([torch.tensor(d[self.sample_key]) for d in batch]) + samples = {self.sample_key: sample_tensor[:, :-1]} + targets = {self.target_key: sample_tensor[:, 1:]} + + print(f"Custom attribute: {self.custom_attribute}") + + return DatasetBatch(targets=targets, samples=samples) + + +def main(): + # load and parse the config file + config_file_path = Path("config_lorem_ipsum.yaml") + config_dict = load_app_config_dict(config_file_path) + + # instantiate the Main entrypoint of modalities by passing in the config + modalities_main = Main(config_dict=config_dict, config_path=config_file_path) + + # add the custom component to modalities + modalities_main.add_custom_component( + component_key="collate_fn", + variant_key="custom_gpt_2_llm_collator", + custom_component=CustomGPT2LLMCollateFn, + custom_config=CustomGPT2LLMCollateFnConfig, + ) + # run the experiment + modalities_main.run() + + +if __name__ == "__main__": + main() diff --git a/examples/library_usage/run.sh b/examples/library_usage/run.sh new file mode 100644 index 00000000..89effdc7 --- /dev/null +++ b/examples/library_usage/run.sh @@ -0,0 +1,3 @@ +#!/bin/sh + +CUDA_VISIBLE_DEVICES=0,1 torchrun --rdzv-endpoint localhost:29504 --nnodes 1 --nproc_per_node 2 main.py \ No newline at end of file diff --git a/examples/pretraining_llama2/train.sh b/examples/pretraining_llama2/train.sh new file mode 100644 index 00000000..c2393e4e --- /dev/null +++ b/examples/pretraining_llama2/train.sh @@ -0,0 +1,3 @@ +#!/bin/sh + +CUDA_VISIBLE_DEVICES=0,1,2,3,4,5 torchrun --rdzv-endpoint localhost:29504 --nnodes 1 --nproc_per_node 6 $(which modalities) run --config_file_path config_example_hf_meditron_7B_instruction.yaml \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index caa789cd..1e3fa5c0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,7 @@ dependencies = [ [project.optional-dependencies] linting = ["pre-commit"] -tests = ["pytest"] +tests = ["pytest", "pytest-cov"] [project.scripts] modalities = "modalities.__main__:main" @@ -45,3 +45,35 @@ line_length = 120 [tool.ruff] line-length = 120 + +[tool.pytest.ini_options] +addopts = "--cov=src --cov-report term --cov-report html" + +[tool.coverage.run] +branch = true +omit = ["*/src/modalities/dataloader/open_gptx_dataset/*"] + +[tool.coverage.report] +# Regexes for lines to exclude from consideration +exclude_also = [ + # Don't complain about missing debug-only code: + "def __repr__", + "if self\\.debug", + + # Don't complain if tests don't hit defensive assertion code: + "raise AssertionError", + "raise 
NotImplementedError", + + # Don't complain if non-runnable code isn't run: + "if 0:", + "if __name__ == .__main__.:", + + # Don't complain about abstract methods, they aren't run: + "@(abc\\.)?abstractmethod", + ] + + +ignore_errors = true + +[tool.coverage.html] +directory = "coverage_html_report" \ No newline at end of file diff --git a/scripts/train.sh b/scripts/train.sh index 1f110c27..142a0e33 100644 --- a/scripts/train.sh +++ b/scripts/train.sh @@ -1,3 +1,3 @@ #!/bin/sh -CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --rdzv-endpoint localhost:29504 --nnodes 1 --nproc_per_node 8 $(which modalities) run --config_file_path /raid/s3/opengptx/max_lue/modalities/config_files/config_example_mem_map_dataset.yaml \ No newline at end of file +CUDA_VISIBLE_DEVICES=0,1,2,3,4,5 torchrun --rdzv-endpoint localhost:29504 --nnodes 1 --nproc_per_node 6 $(which modalities) run --config_file_path ../config_files/config_example_mem_map_dataset.yaml \ No newline at end of file diff --git a/src/modalities/__main__.py b/src/modalities/__main__.py index 8f709130..1b712250 100644 --- a/src/modalities/__main__.py +++ b/src/modalities/__main__.py @@ -1,42 +1,32 @@ #!/usr/bin/env python import logging +import os +import shutil from pathlib import Path -from typing import Dict, List, Tuple +from typing import Dict, Tuple import click import click_pathlib -import torch -import torch.nn as nn -from omegaconf import OmegaConf -from torch.optim import Optimizer from modalities.activation_checkpointing import apply_activation_checkpointing_inplace from modalities.batch import EvaluationResultBatch -from modalities.checkpointing.checkpointing import Checkpointing, CheckpointingIF -from modalities.checkpointing.checkpointing_factory import CheckpointingFactory -from modalities.config.config import AppConfig, ModalitiesSetupConfig, RunMode -from modalities.config.lookup_types import TokenizerTypes +from modalities.config.component_factory import ComponentFactory +from modalities.config.config import ComponentsModel, ProcessGroupBackendType, TokenizerTypes, load_app_config_dict from modalities.dataloader.create_index import IndexGenerator -from modalities.dataloader.create_packed_data import PackedDataGenerator -from modalities.dataloader.dataloader import LLMDataLoader -from modalities.dataloader.dataloader_factory import DataloaderFactory +from modalities.dataloader.create_packed_data import EmbeddedStreamData, PackedDataGenerator, join_embedded_stream_data from modalities.dataloader.large_file_lines_reader import LargeFileLinesReader from modalities.evaluator import Evaluator from modalities.gym import Gym from modalities.logging_broker.message_broker import MessageBroker from modalities.logging_broker.messages import BatchProgressUpdate, MessageTypes from modalities.logging_broker.publisher import MessagePublisher -from modalities.logging_broker.subscriber_impl.batch_progress_subscriber import ( - DummyProgressSubscriber, - RichProgressSubscriber, -) -from modalities.logging_broker.subscriber_impl.results_subscriber import WandBEvaluationResultSubscriber -from modalities.loss_functions import Loss -from modalities.resolver_register import ResolverRegister -from modalities.running_env.fsdp.fsdp_running_env import RunningEnv +from modalities.logging_broker.subscriber import MessageSubscriberIF +from modalities.registry.components import COMPONENTS +from modalities.registry.registry import Registry +from modalities.running_env.cuda_env import CudaEnv from modalities.trainer import Trainer -from modalities.util import 
compute_number_of_trainable_parameters, get_date_of_run +from modalities.util import compute_number_of_trainable_parameters, get_callback_interval_in_batches_per_rank from modalities.utils.generate_text import main as generate_text_main @@ -54,8 +44,7 @@ def main() -> None: ) def entry_point_run_modalities(config_file_path: Path): config_dict = load_app_config_dict(config_file_path) - config = AppConfig.model_validate(config_dict) - main = Main(config) + main = Main(config_dict, config_file_path) main.run() @@ -83,7 +72,15 @@ def entry_point_generate_text(model_path, config_path, tokenizer_type, tokenizer generate_text_main(model_path, config_path, tokenizer, max_new_tokens, chat) -@main.command(name="create_memmap_index") +@main.group(name="data") +def data(): + """ + Collection of utilities to preprocess, analyse and modify training data. + """ + pass + + +@data.command(name="create_raw_index") @click.argument("src_path", type=Path) @click.option( "--index_path", @@ -91,7 +88,13 @@ def entry_point_generate_text(model_path, config_path, tokenizer_type, tokenizer default=None, help="output path for index. will use parent directory of src_path if none.", ) -def entry_point_create_memmap_index(src_path, index_path): +def entry_point_data_create_raw_index(src_path, index_path): + """ + Utility for indexing the content of a large jsonl-file. + The index enables line-based access to the file's samples without loading the whole file into memory + and is a prerequisite for further processing steps such as tokenization. + It has to be created only once per jsonl-file; different tokenizations can then reuse the same index. + """ index_path = LargeFileLinesReader.default_index_path(src_path, index_path) if index_path.exists(): raise ValueError("index already exists. delete it or specify different output folder.") @@ -102,7 +105,7 @@ def entry_point_create_memmap_index(src_path, index_path): generator.create_index(index_path) -@main.command(name="create_packed_data") +@data.command(name="pack_encoded_data") @click.argument("src_path", type=Path) @click.option( "--dst_path", @@ -137,7 +140,21 @@ def entry_point_create_memmap_index(src_path, index_path): default=".text", help="jq pattern to extract the data from the json line.", ) -def entry_point_create_packed_data(src_path, dst_path, index_path, tokenizer_type, tokenizer_file, jq_pattern): +@click.option( + "--num-cpus", + type=int, + show_default=True, + default=os.cpu_count(), + help="Specify the number of tokenization workers. Default is the number of available CPUs.", +) +def entry_point_pack_encoded_data(src_path, dst_path, index_path, tokenizer_type, tokenizer_file, jq_pattern, num_cpus): + """ + Utility to tokenize and pack an indexed, large jsonl-file into a .pbin-file. + + (see also `create_raw_index` for how to create the required index) + The resulting .pbin-file can be fed into a training run directly + and no longer requires the original jsonl-file or its index file. + """ # TODO: if we want to use alternative entrypoints together with the ResolverRegistry, # we can currently not rely on the existing class resolver. # This is based on its connection to the overall `AppConfig`. # This could get resolved by implementing on own ResolverRegistry for each entrypoint or adapting the existing # ResolverRegistry to work dynamically with any type-hinted config object from config.py.
tokenizer = tokenizer_type.value(tokenizer_file=str(tokenizer_file)) - generator = PackedDataGenerator(src_path, index_path=index_path, tokenizer=tokenizer, jq_pattern=jq_pattern) + generator = PackedDataGenerator( + src_path, + index_path=index_path, + tokenizer=tokenizer, + jq_pattern=jq_pattern, + number_of_processes=num_cpus, + ) generator.run(dst_path) -def load_app_config_dict(config_file_path: Path) -> Dict: - cfg = OmegaConf.load(config_file_path) - logging.info(f"Config\n {OmegaConf.to_yaml(cfg, resolve=True)}") - return OmegaConf.to_container(cfg, resolve=True) +@data.command(name="merge_packed_data") +@click.argument("src_paths", type=click.types.Path(exists=True, path_type=Path), nargs=-1, required=True) +@click.argument("target_path", type=click.types.Path(file_okay=False, dir_okay=False, path_type=Path)) +def entry_point_merge_packed_data(src_paths, target_path): + """ + Utility for merging different pbin-files into one. + This is especially useful if the different datasets were created at different points in time + or if a single encoding run takes so long that the overall process is split into chunks. + It is important that the same tokenizer was used for all chunks. + + Specify an arbitrary number of pbin-files and/or directories containing pbin-files as input. + """ + input_files = [] + for p in src_paths: + p: Path + if p.is_dir(): + input_files.extend(p.glob("**/*.pbin")) + else: + input_files.append(p) + embedded_datasets = list(map(EmbeddedStreamData, input_files)) + join_embedded_stream_data(embedded_datasets, target_path) class Main: - def __init__(self, config: AppConfig) -> None: - self.config = config - self.experiment_id = get_date_of_run() - - self.resolvers = ResolverRegister(config=config) - self.running_env: RunningEnv = self.resolvers.build_component_by_config(config=self.config.running_env) + def __init__(self, config_dict: Dict, config_path: Path) -> None: + self.config_dict = config_dict + self.config_path = config_path + + self.registry = Registry(COMPONENTS) + self.component_factory = ComponentFactory(registry=self.registry) + + def add_custom_component(self, component_key: str, variant_key: str, custom_component, custom_config) -> None: + self.registry.add_entity( + component_key=component_key, + variant_key=variant_key, + component_type=custom_component, + component_config_type=custom_config, + ) def run(self): - with self.running_env as running_env: - ( - gym, - train_dataloader, - eval_data_loaders, - checkpointing, - wrapped_model, - optimizer, - ) = self.construct_components(resolvers=self.resolvers, config=self.config, running_env=running_env) - - logging.info(f"Training model with {compute_number_of_trainable_parameters(wrapped_model)} parameters.") - - gym.run( - callback_interval_in_batches=self.config.training.callback_interval_in_batches_per_rank, - train_data_loader=train_dataloader, - evaluation_data_loaders=eval_data_loaders, - checkpointing=checkpointing, - model=wrapped_model, - optimizer=optimizer, + with CudaEnv(process_group_backend=ProcessGroupBackendType.nccl): + components: ComponentsModel = self.component_factory.build_components( + config_dict=self.config_dict, components_model_type=ComponentsModel ) - def construct_components( - self, resolvers: ResolverRegister, config: AppConfig, running_env: RunningEnv - ) -> Tuple[Gym, LLMDataLoader, List[LLMDataLoader], CheckpointingIF, nn.Module, Optimizer]: - # Checkpointing - - checkpointing = CheckpointingFactory.get_checkpointing( - resolvers=self.resolvers, - config=config.checkpointing, -
running_env=running_env, - experiment_id=self.experiment_id, - num_ranks=config.training.world_size, - ) - - # Model and optimizer - wrapped_model, optimizer = self.get_model_and_optimizer( - config=config, running_env=running_env, checkpointing=checkpointing - ) - if config.training.do_apply_activation_checkpointing: - apply_activation_checkpointing_inplace(wrapped_model) - logging.info("Applied activation checkpointing!") - - # Loss function - loss_fun: Loss = resolvers.build_component_by_config(config=config.loss) - - # Dataloaders - # skip_num_samples = 0 - # if run_mode == RunMode.WARM_START: - # skip_num_samples = config.modalities_setup.settings.checkpoint_num_seen_samples - - skip_num_local_train_batches = config.training.skip_num_local_train_batches - train_dataloader = DataloaderFactory.get_dataloader( - resolvers=resolvers, config=config.data.train_dataloader, skip_num_batches=skip_num_local_train_batches - ) - eval_dataloaders = [ - DataloaderFactory.get_dataloader(resolvers=resolvers, config=dataloader_config) - for dataloader_config in config.data.eval_dataloaders - ] - - # Logging - eval_split_lengths = { - dataloader.dataloader_tag: len(dataloader) * config.training.world_size * dataloader.sampler_batch_size - for dataloader in eval_dataloaders - } - - # TODO: check why not *config.training.world_size - # and consider just using config.training.num_training_samples for progress Subscriber - train_split_lengths = { - train_dataloader.dataloader_tag: (len(train_dataloader) + skip_num_local_train_batches) - * config.training.world_size - * train_dataloader.sampler_batch_size - } - - evaluation_result_publisher, batch_processed_publisher = self.get_logging_publishers( - config=config, train_split_lengths=train_split_lengths, eval_split_lengths=eval_split_lengths - ) - - # Trainer - trainer = Trainer( - local_rank=config.training.local_rank, - batch_progress_publisher=batch_processed_publisher, - evaluation_result_publisher=evaluation_result_publisher, - gradient_acc_step=config.training.gradient_acc_step, - ) - - # Evaluator - evaluator = Evaluator( - local_rank=config.training.local_rank, - batch_progress_publisher=batch_processed_publisher, - evaluation_result_publisher=evaluation_result_publisher, - ) - - # Gym - gym = Gym(trainer=trainer, evaluator=evaluator, loss_fun=loss_fun, num_ranks=config.training.world_size) - - return gym, train_dataloader, eval_dataloaders, checkpointing, wrapped_model, optimizer - - def get_model_and_optimizer( - self, config: AppConfig, running_env: RunningEnv, checkpointing: Checkpointing - ) -> Tuple[nn.Module, Optimizer]: - run_mode = config.modalities_setup.run_mode - - model: torch.nn.Module = self.resolvers.build_component_by_config(config=config.model) - - if run_mode == RunMode.WARM_START: - warm_start_settings: ModalitiesSetupConfig.WarmStartSettings = config.modalities_setup.settings - wrapped_model = checkpointing.load_model_checkpoint( - file_path=warm_start_settings.checkpoint_model_path, - model=model, + # save the config file to the checkpointing path + if components.settings.cuda_env.global_rank == 0: + experiment_path = components.settings.paths.checkpointing_path / components.settings.experiment_id + os.makedirs(experiment_path, exist_ok=True) + shutil.copy(self.config_path, experiment_path / self.config_path.name) + + evaluation_result_publisher, batch_processed_publisher = self.get_logging_publishers( + progress_subscriber=components.batch_progress_subscriber, + results_subscriber=components.evaluation_subscriber, + 
global_rank=components.settings.cuda_env.global_rank, + local_rank=components.settings.cuda_env.local_rank, ) - optimizer: torch.optim.Optimizer = self.resolvers.build_component_by_config( - config=config.optimizer, extra_kwargs=dict(params=wrapped_model.parameters()) + # Trainer + trainer = Trainer( + local_rank=components.settings.cuda_env.local_rank, + batch_progress_publisher=batch_processed_publisher, + evaluation_result_publisher=evaluation_result_publisher, + gradient_acc_steps=components.settings.training.gradient_acc_steps, ) - # TODO improve this - if warm_start_settings.checkpoint_optimizer_path is None: - raise ( - NotImplementedError( - "So far we always have to provide an optimizer checkpoint. " - "For fine-tuning a pre-trained, we might not want to load " - "an optimizer checkpoint." - ) - ) - - optimizer = checkpointing.load_optimizer_checkpoint( - optimizer=optimizer, model=wrapped_model, file_path=warm_start_settings.checkpoint_optimizer_path + # Evaluator + evaluator = Evaluator( + local_rank=components.settings.cuda_env.local_rank, + batch_progress_publisher=batch_processed_publisher, + evaluation_result_publisher=evaluation_result_publisher, ) - else: - wrapped_model = running_env.wrap_model(model=model, sync_module_states=False) - optimizer: torch.optim.Optimizer = self.resolvers.build_component_by_config( - config=config.optimizer, extra_kwargs=dict(params=wrapped_model.parameters()) + # Gym + gym = Gym( + trainer=trainer, + evaluator=evaluator, + loss_fun=components.loss_fn, + num_ranks=components.settings.cuda_env.world_size, ) + wrapped_model = components.wrapped_model + logging.info(f"Training model with {compute_number_of_trainable_parameters(wrapped_model)} parameters.") + + if components.settings.training.do_apply_activation_checkpointing: + apply_activation_checkpointing_inplace(wrapped_model) - # TODO implement scheduler - # scheduler = self.resolvers.build_component_by_config( - # config=config.scheduler, extra_kwargs=dict(optimizer=self.optimizer) - # ) + callback_interval_in_batches_per_rank = get_callback_interval_in_batches_per_rank( + callback_interval_in_samples=components.settings.training.callback_interval_in_samples, + local_train_micro_batch_size=components.settings.training.local_train_micro_batch_size, + gradient_acc_steps=components.settings.training.gradient_acc_steps, + world_size=components.settings.cuda_env.world_size, + ) - return wrapped_model, optimizer + gym.run( + callback_interval_in_batches=callback_interval_in_batches_per_rank, + train_data_loader=components.train_dataloader, + evaluation_data_loaders=components.eval_dataloaders, + checkpointing=components.checkpointing, + model=wrapped_model, + optimizer=components.optimizer, + ) + print("done") def get_logging_publishers( - self, config: AppConfig, train_split_lengths: Dict[str, int], eval_split_lengths: Dict[str, int] + self, + progress_subscriber: MessageSubscriberIF[BatchProgressUpdate], + results_subscriber: MessageSubscriberIF[EvaluationResultBatch], + global_rank: int, + local_rank: int, ) -> Tuple[MessagePublisher[EvaluationResultBatch], MessagePublisher[BatchProgressUpdate],]: - # Message Broker message_broker = MessageBroker() batch_processed_publisher = MessagePublisher[BatchProgressUpdate]( message_broker=message_broker, - global_rank=config.training.global_rank, - local_rank=config.training.local_rank, + global_rank=global_rank, + local_rank=local_rank, ) evaluation_result_publisher = MessagePublisher[EvaluationResultBatch]( message_broker=message_broker, - 
global_rank=config.training.global_rank, - local_rank=config.training.local_rank, + global_rank=global_rank, + local_rank=local_rank, ) - # TODO make logging rank configurable - # TODO: make this instantiation of subscribers configurable via config.yml and use "build_component_by_config" - if config.training.global_rank == 0: - progress_subscriber = RichProgressSubscriber( - num_ranks=config.training.world_size, - train_split_num_samples=train_split_lengths, - eval_splits_num_samples=eval_split_lengths, - ) - evaluation_result_subscriber = WandBEvaluationResultSubscriber( - num_ranks=config.training.world_size, - project=config.wandb.project_name, - experiment_id=self.experiment_id, - mode=config.wandb.mode, - dir=config.wandb.dir, - experiment_config=config, - ) - message_broker.add_subscriber( - subscription=MessageTypes.EVALUATION_RESULT, subscriber=evaluation_result_subscriber - ) - - else: - progress_subscriber = DummyProgressSubscriber() + message_broker.add_subscriber(subscription=MessageTypes.EVALUATION_RESULT, subscriber=results_subscriber) message_broker.add_subscriber( subscription=MessageTypes.BATCH_PROGRESS_UPDATE, subscriber=progress_subscriber, diff --git a/src/modalities/activation_checkpointing.py b/src/modalities/activation_checkpointing.py index a6a4fd9a..288dc09f 100644 --- a/src/modalities/activation_checkpointing.py +++ b/src/modalities/activation_checkpointing.py @@ -8,11 +8,11 @@ ) from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP -from modalities.models.gpt2.gpt2_model import Block +from modalities.models.gpt2.gpt2_model import GPT2Block def is_module_to_apply_activation_checkpointing(submodule: torch.nn.Module) -> bool: - return isinstance(submodule, Block) + return isinstance(submodule, GPT2Block) def apply_activation_checkpointing_inplace(model: torch.nn.Module): diff --git a/src/modalities/batch.py b/src/modalities/batch.py index bc6c62c0..7cf3f34e 100644 --- a/src/modalities/batch.py +++ b/src/modalities/batch.py @@ -103,12 +103,15 @@ class EvaluationResultBatch(Batch): losses: Dict[str, torch.Tensor] = field(default_factory=lambda: dict()) metrics: Dict[str, torch.Tensor] = field(default_factory=lambda: dict()) throughput_metrics: Dict[str, torch.Tensor] = field(default_factory=lambda: dict()) + def __str__(self) -> str: eval_str = ( f"Evaluation result on dataset tag {self.dataloader_tag} after {self.global_train_sample_id + 1} samples:" ) eval_str += "\n\nlosses: " + "\n\t".join([f"{k}: {v.mean().item()}" for k, v in self.losses.items()]) eval_str += "\n\nmetrics: " + "\n\t".join([f"{k}: {v.mean().item()}" for k, v in self.metrics.items()]) - eval_str += "\n\nthroughput metrics: " + "\n\t".join([f"{k}: {v.mean().item()}" for k, v in self.throughput_metrics.items()]) + eval_str += "\n\nthroughput metrics: " + "\n\t".join( + [f"{k}: {v.mean().item()}" for k, v in self.throughput_metrics.items()] + ) eval_str += "\n===============================================" return eval_str diff --git a/src/modalities/checkpointing/checkpointing.py b/src/modalities/checkpointing/checkpointing.py index 611e8e48..6d5d80b5 100644 --- a/src/modalities/checkpointing/checkpointing.py +++ b/src/modalities/checkpointing/checkpointing.py @@ -45,11 +45,9 @@ def __init__( self, checkpointing_strategy: CheckpointingStrategyIF, checkpointing_execution: CheckpointingExecutionIF, - num_ranks: int, ): self.checkpointing_strategy = checkpointing_strategy self.checkpointing_execution = checkpointing_execution - self.num_ranks = num_ranks def 
save_checkpoint( self, @@ -76,10 +74,10 @@ def load_model_checkpoint(self, model: nn.Module, file_path: Path) -> nn.Module: model = self.checkpointing_execution.load_model_checkpoint(model=model, file_path=file_path) return model - def load_optimizer_checkpoint(self, optimizer: Optimizer, model: nn.Module, file_path: Path) -> Optimizer: + def load_optimizer_checkpoint(self, optimizer: Optimizer, wrapped_model: nn.Module, file_path: Path) -> Optimizer: optimizer = self.checkpointing_execution.load_optimizer_checkpoint( optimizer=optimizer, - model=model, + wrapped_model=wrapped_model, file_path=file_path, ) return optimizer diff --git a/src/modalities/checkpointing/checkpointing_execution.py b/src/modalities/checkpointing/checkpointing_execution.py index 4b94c414..cfe89306 100644 --- a/src/modalities/checkpointing/checkpointing_execution.py +++ b/src/modalities/checkpointing/checkpointing_execution.py @@ -1,18 +1,19 @@ from abc import ABC, abstractmethod from enum import Enum from pathlib import Path -from typing import Callable, List +from typing import List import torch import torch.distributed as dist import torch.nn as nn from torch.distributed.fsdp import FullOptimStateDictConfig, FullStateDictConfig from torch.distributed.fsdp import FullyShardedDataParallel as FSDP -from torch.distributed.fsdp import StateDictType +from torch.distributed.fsdp import ShardingStrategy, StateDictType from torch.optim import Optimizer from modalities.checkpointing.checkpointing_instruction import CheckpointingInstruction from modalities.exceptions import CheckpointingError +from modalities.running_env.env_utils import MixedPrecisionSettings class CheckpointingEntityType(Enum): @@ -29,7 +30,7 @@ def load_model_checkpoint(self, model: nn.Module, file_path: Path) -> nn.Module: def load_optimizer_checkpoint( self, optimizer: Optimizer, - model: nn.Module, + wrapped_model: nn.Module, file_path: Path, ) -> Optimizer: raise NotImplementedError @@ -76,7 +77,9 @@ def __init__( checkpoint_path: Path, experiment_id: str, global_rank: int, - model_wrapping_fn: Callable[[nn.Module, bool], FSDP], + block_names: List[str], + mixed_precision_settings: MixedPrecisionSettings, + sharding_strategy: ShardingStrategy, ): """ Implementation of checkpointing to disc via FSDP @@ -85,13 +88,13 @@ def __init__( checkpoint_path (Path): folder path to the checkpoint experiment_id (str): ID of the experiment global_rank (int): global rank within the current process group - model_wrapping_fn (Callable[[nn.Module, bool], FSDP]): Wrapping function that wraps raw model. 
- For FSDP, we pass in FSDPRunningEnv.wrap_model """ self.checkpoint_path = checkpoint_path self.global_rank = global_rank - self.model_wrapping_fn = model_wrapping_fn self.experiment_id = experiment_id + self.block_names = block_names + self.mixed_precision_settings = mixed_precision_settings + self.sharding_strategy = sharding_strategy def _get_checkpointing_path( self, @@ -186,10 +189,20 @@ def load_model_checkpoint(self, model: nn.Module, file_path: Path) -> nn.Module: # load model on rank 0 into CPU RAM model_state = torch.load(file_path) model.load_state_dict(model_state) - fsdp_model = self.model_wrapping_fn(model=model, sync_module_states=True) + + # TODO nasty workaround to prevent circular imports + from modalities.models.model_factory import ModelFactory + + fsdp_model = ModelFactory.get_fsdp_wrapped_model( + model=model, + sync_module_states=True, + block_names=self.block_names, + mixed_precision_settings=self.mixed_precision_settings, + sharding_strategy=self.sharding_strategy, + ) return fsdp_model - def load_optimizer_checkpoint(self, optimizer: Optimizer, model: FSDP, file_path: Path) -> Optimizer: + def load_optimizer_checkpoint(self, optimizer: Optimizer, wrapped_model: FSDP, file_path: Path) -> Optimizer: # load optimizer full_optimizer_state_dict = None if self.global_rank == 0: @@ -198,7 +211,7 @@ def load_optimizer_checkpoint(self, optimizer: Optimizer, model: FSDP, file_path # distribute the optimizer state dict from rank 0 to all the other ranks sharded_optimizer_state_dict = FSDP.scatter_full_optim_state_dict( - full_optim_state_dict=full_optimizer_state_dict, model=model, group=None + full_optim_state_dict=full_optimizer_state_dict, model=wrapped_model, group=None ) optimizer.load_state_dict(sharded_optimizer_state_dict) diff --git a/src/modalities/checkpointing/checkpointing_factory.py b/src/modalities/checkpointing/checkpointing_factory.py deleted file mode 100644 index 138569be..00000000 --- a/src/modalities/checkpointing/checkpointing_factory.py +++ /dev/null @@ -1,36 +0,0 @@ -from modalities.checkpointing.checkpointing import ( - Checkpointing, - CheckpointingExecutionIF, - CheckpointingIF, - CheckpointingStrategyIF, -) -from modalities.config.config import CheckpointingConfig -from modalities.resolver_register import ResolverRegister -from modalities.running_env.fsdp.fsdp_running_env import RunningEnv - - -class CheckpointingFactory: - @staticmethod - def get_checkpointing( - resolvers: ResolverRegister, - config: CheckpointingConfig, - running_env: RunningEnv, - experiment_id: str, - num_ranks: int, - ) -> CheckpointingIF: - checkpointing_strategy: CheckpointingStrategyIF = resolvers.build_component_by_config( - config=config.checkpointing_strategy, extra_kwargs={} - ) - - checkpointing_execution: CheckpointingExecutionIF = resolvers.build_component_by_config( - config=config.checkpointing_execution, - extra_kwargs={"experiment_id": experiment_id, "model_wrapping_fn": running_env.wrap_model}, - ) - - checkpointing = Checkpointing( - checkpointing_strategy=checkpointing_strategy, - checkpointing_execution=checkpointing_execution, - num_ranks=num_ranks, - ) - - return checkpointing diff --git a/src/modalities/config/component_factory.py b/src/modalities/config/component_factory.py new file mode 100644 index 00000000..c3a3dfd5 --- /dev/null +++ b/src/modalities/config/component_factory.py @@ -0,0 +1,144 @@ +from typing import Any, Dict, List, Type, TypeVar, Union + +from pydantic import BaseModel + +from modalities.registry.registry import Registry + + +class 
ComponentFactory: + def __init__(self, registry: Registry) -> None: + self.registry = registry + + BaseModelChild = TypeVar("BaseModelChild", bound=BaseModel) + + def build_components(self, config_dict: Dict, components_model_type: Type[BaseModelChild]) -> BaseModelChild: + component_names = list(components_model_type.model_fields.keys()) + component_dict = self._build_config(config_dict=config_dict, component_names=component_names) + print(component_dict) + components = components_model_type(**component_dict) + return components + + def _build_config(self, config_dict: Dict, component_names: List[str]) -> Dict[str, Any]: + component_dict_filtered = {name: config_dict[name] for name in component_names} + components, _ = self._build_component( + current_component_config=component_dict_filtered, + component_config=config_dict, + top_level_components={}, + traversal_path=[], + ) + return components + + def _build_component( + self, + current_component_config: Union[Dict, List, Any], + component_config: Union[Dict, List, Any], + top_level_components: Dict[str, Any], + traversal_path: List, + ) -> Any: + # build sub components first via recursion + if isinstance(current_component_config, dict): + # if the entities are top level components, we return the component, + # as it must have been built already via a referencing component + if len(traversal_path) > 0 and traversal_path[-1] in top_level_components: + entity_key = traversal_path[-1] + return top_level_components[entity_key], top_level_components + # if it is not a component that has been built already, we need to build it. + # We first traverse the config for possible sub components that need to build beforehand. + materialized_component_config = {} + for sub_entity_key, sub_component_config_dict in current_component_config.items(): + materialized_component_config[sub_entity_key], top_level_components = self._build_component( + current_component_config=sub_component_config_dict, + component_config=component_config, + top_level_components=top_level_components, + traversal_path=traversal_path + [sub_entity_key], + ) + # After building all the sub components, we can now build the actual component + # if the config is component_config then we instantiate the component + if ComponentFactory._is_component_config(config_dict=current_component_config): + # instantiate component config + component_key = current_component_config["component_key"] + variant_key = current_component_config["variant_key"] + current_component_config = self._instantiate_component_config( + component_key=component_key, + variant_key=variant_key, + config_dict=materialized_component_config["config"], + ) + # instantiate component + component = self._instantiate_component( + component_key=component_key, variant_key=variant_key, component_config=current_component_config + ) + print(" -> ".join(traversal_path) + ":", component) + + # if the component is a top level component, then we add it to the top level components dictionary + # to make sure that we don't build it again. Building it again would mean that we work by-value + # instead of by reference. 
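+ # e.g., a component under the top level config key "train_dataloader" is built once here and cached, so that + # reference configs pointing to it later resolve to this same instance rather than to a freshly built copy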
+ if len(traversal_path) == 1: + entity_key = traversal_path[-1] + top_level_components[entity_key] = component + return component, top_level_components + + # if the config is a reference_config then check if it exists and if not, we build it + if ComponentFactory._is_reference_config(config_dict=current_component_config): + referenced_entity_key = current_component_config["instance_key"] + if referenced_entity_key not in top_level_components: + materialized_referenced_component, top_level_components = self._build_component( + current_component_config=component_config[referenced_entity_key], + component_config=component_config, + top_level_components=top_level_components, + traversal_path=[referenced_entity_key], + ) + # we add the newly build reference config to the top level components dict + # so that we don't instantiate it again when we reach the respective component config + # in the subsequent config traversal + top_level_components[referenced_entity_key] = materialized_referenced_component + print(" -> ".join(traversal_path) + ": ", f"--ref--> {top_level_components[referenced_entity_key]}") + return top_level_components[referenced_entity_key], top_level_components + + return materialized_component_config, top_level_components + + elif isinstance(current_component_config, list): + materialized_component_configs = [] + for sub_entity_key, sub_component_config in enumerate(current_component_config): + materialized_component_config, top_level_components = self._build_component( + current_component_config=sub_component_config, + component_config=component_config, + top_level_components=top_level_components, + traversal_path=traversal_path + [str(sub_entity_key)], + ) + materialized_component_configs.append(materialized_component_config) + return materialized_component_configs, top_level_components + + else: + # we return the raw sub config if the sub config is not a dictionary or a list + # i.e., just a "scalar" value (e.g., string, int, etc.), since we don't have to build it. + return current_component_config, top_level_components + + @staticmethod + def _is_component_config(config_dict: Dict) -> bool: + # TODO instead of field checks, we should introduce an enum for the config type. + return "component_key" in config_dict.keys() + + @staticmethod + def _is_reference_config(config_dict: Dict) -> bool: + # TODO instead of field checks, we should introduce an enum for the config type. 
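+ # e.g., {"instance_key": "train_dataloader", "pass_type": "BY_REFERENCE"} is treated as a reference config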
+ return {"instance_key", "pass_type"} == config_dict.keys() + + def _instantiate_component_config(self, component_key: str, variant_key: str, config_dict: Dict) -> BaseModel: + component_config_type: Type[BaseModel] = self.registry.get_config(component_key, variant_key) + comp_config = component_config_type(**config_dict, strict=True) + return comp_config + + def _instantiate_component(self, component_key: str, variant_key: str, component_config: BaseModel) -> Any: + component_type: Type = self.registry.get_component(component_key, variant_key) + component_config_dict = self.base_model_to_dict(component_config) + component = component_type(**component_config_dict) + return component + + @staticmethod + def base_model_to_dict(base_model: BaseModel) -> Dict: + # converts top level structure of base_model into dictionary while maintaining substructure + output = {} + for name, _ in base_model.model_fields.items(): + value = getattr(base_model, name) + output[name] = value + return output diff --git a/src/modalities/config/config.py b/src/modalities/config/config.py index 0e166242..353cf14e 100644 --- a/src/modalities/config/config.py +++ b/src/modalities/config/config.py @@ -1,345 +1,336 @@ -import json -import warnings -from enum import Enum +import os from pathlib import Path -from typing import List, Optional, Union - -from pydantic import BaseModel, Field, FilePath, PositiveFloat, PositiveInt, confloat, conint, model_validator -from transformers import PretrainedConfig - -from modalities.config.lookup_types import ( - BatchSamplerTypes, - CheckpointingExectionTypes, - CheckpointingStrategyTypes, - CollatorTypes, - DataloaderTypes, - DatasetTypes, - LossTypes, - ModelTypes, - OptimizerTypes, - SamplerTypes, - SchedulerTypes, - TokenizerTypes, -) -from modalities.config.types import ProcessGroupBackendType -from modalities.models.gpt2.gpt2_model import GPT2Config -from modalities.running_env.fsdp.fsdp_running_env import RunningEnvConfig - - -class WandbConfig(BaseModel): - class WandbMode(Enum): - ONLINE = "ONLINE" - OFFLINE = "OFFLINE" - DISABLED = "DISABLED" - - project_name: str - mode: WandbMode - dir: Optional[Path] = Field(default_factory=lambda: Path(".")) +from typing import Annotated, Any, Dict, List, Optional + +import torch.nn as nn +from omegaconf import OmegaConf +from pydantic import BaseModel, Field, FilePath, GetCoreSchemaHandler, PositiveInt, field_validator +from pydantic_core import core_schema +from torch.distributed.fsdp import ShardingStrategy +from torch.optim import Optimizer +from torch.utils.data import Sampler +from torch.utils.data.dataset import Dataset +from transformers import GPT2TokenizerFast +from transformers.models.llama.tokenization_llama_fast import LlamaTokenizerFast +from transformers.tokenization_utils_fast import PreTrainedTokenizerFast + +from modalities.checkpointing.checkpointing import CheckpointingIF +from modalities.checkpointing.checkpointing_execution import CheckpointingExecutionIF +from modalities.checkpointing.checkpointing_strategies import CheckpointingStrategyIF +from modalities.config.lookup_enum import LookupEnum +from modalities.dataloader.dataloader import LLMDataLoader +from modalities.logging_broker.subscriber import MessageSubscriberIF +from modalities.loss_functions import Loss +from modalities.models.gpt2.collator import CollateFnIF +from modalities.running_env.env_utils import MixedPrecisionSettings, has_bfloat_support +from modalities.util import get_date_of_run, parse_enum_by_name + + +class PydanticThirdPartyTypeIF: + def 
__init__(self, third_party_type): + self.third_party_type = third_party_type + + def __get_pydantic_core_schema__( + self, + _source_type: Any, + _handler: GetCoreSchemaHandler, + ) -> core_schema.CoreSchema: + # see: https://docs.pydantic.dev/latest/concepts/types/#handling-third-party-types + return core_schema.json_or_python_schema( + json_schema=core_schema.is_instance_schema(self.third_party_type), + python_schema=core_schema.is_instance_schema(self.third_party_type), + # serialization=core_schema.plain_serializer_function_ser_schema( + # lambda instance: instance.x + # ), + ) + + +PydanticCheckpointingIFType = Annotated[CheckpointingIF, PydanticThirdPartyTypeIF(CheckpointingIF)] +PydanticCheckpointingStrategyIFType = Annotated[ + CheckpointingStrategyIF, PydanticThirdPartyTypeIF(CheckpointingStrategyIF) +] +PydanticCheckpointingExecutionIFType = Annotated[ + CheckpointingExecutionIF, PydanticThirdPartyTypeIF(CheckpointingExecutionIF) +] +PydanticModelIFType = Annotated[nn.Module, PydanticThirdPartyTypeIF(nn.Module)] +PydanticTokenizerIFType = Annotated[PreTrainedTokenizerFast, PydanticThirdPartyTypeIF(PreTrainedTokenizerFast)] +PydanticDatasetIFType = Annotated[Dataset, PydanticThirdPartyTypeIF(Dataset)] +PydanticSamplerIFType = Annotated[Sampler, PydanticThirdPartyTypeIF(Sampler)] +PydanticCollateFnIFType = Annotated[CollateFnIF, PydanticThirdPartyTypeIF(CollateFnIF)] +PydanticLLMDataLoaderIFType = Annotated[LLMDataLoader, PydanticThirdPartyTypeIF(LLMDataLoader)] +PydanticOptimizerIFType = Annotated[Optimizer, PydanticThirdPartyTypeIF(Optimizer)] +PydanticLossIFType = Annotated[Loss, PydanticThirdPartyTypeIF(Loss)] +PydanticMessageSubscriberIFType = Annotated[MessageSubscriberIF, PydanticThirdPartyTypeIF(MessageSubscriberIF)] + + +class ProcessGroupBackendType(LookupEnum): + nccl = "nccl" + + +class TokenizerTypes(LookupEnum): + GPT2TokenizerFast = GPT2TokenizerFast + LlamaTokenizerFast = LlamaTokenizerFast + + +class PassType(LookupEnum): + BY_VALUE = "by_value" + BY_REFERENCE = "by_reference" + + +class WandbMode(LookupEnum): + ONLINE = "ONLINE" + OFFLINE = "OFFLINE" + DISABLED = "DISABLED" + + +class ReferenceConfig(BaseModel): + instance_key: str + pass_type: PassType -class CudaKwargsConfig(BaseModel): - num_workers: conint(ge=0) - pin_memory: bool - shuffle: bool +class CLMCrossEntropyLossConfig(BaseModel): + target_key: str + prediction_key: str + +# Checkpointing +class SaveEveryKStepsCheckpointingStrategyConfig(BaseModel): + k: PositiveInt -class TokenizerConfig(BaseModel): - class GPT2TokenizerFastConfig(BaseModel): - tokenizer_file: str # FilePath not possible, since transformers.PretrainedTokenizers can only handle strings - type_hint: TokenizerTypes - config: GPT2TokenizerFastConfig +class SaveKMostRecentCheckpointsStrategyConfig(BaseModel): + k: Annotated[int, Field(strict=True, ge=-1)] -class DatasetConfig(BaseModel): - class MemMapDatasetConfig(BaseModel): - raw_data_path: FilePath - index_path: Optional[FilePath] = None - block_size: conint(gt=0) - tokenizer: TokenizerConfig - jq_pattern: str - sample_key: str +class FSDPToDiscCheckpointingConfig(BaseModel): + checkpoint_path: Path + global_rank: Annotated[int, Field(strict=True, ge=0)] + experiment_id: str + block_names: List[str] + mixed_precision_settings: MixedPrecisionSettings + sharding_strategy: ShardingStrategy - class PackedMemMapDatasetContinuousConfig(BaseModel): - raw_data_path: Path - block_size: conint(gt=0) - sample_key: str + @field_validator("mixed_precision_settings", mode="before") + def 
parse_mixed_precision_setting_by_name(cls, name): + mixed_precision_settings: MixedPrecisionSettings = parse_enum_by_name( + name=name, enum_type=MixedPrecisionSettings + ) + if not has_bfloat_support() and ( + mixed_precision_settings == MixedPrecisionSettings.BF_16 + or mixed_precision_settings == MixedPrecisionSettings.BF_16_WORKING + ): + raise ValueError("BF16 not supported in the current environment") + return mixed_precision_settings - class PackedMemMapDatasetMegatronConfig(BaseModel): - raw_data_path: Path - block_size: conint(gt=0) - sample_key: str + @field_validator("sharding_strategy", mode="before") + def parse_sharding_strategy_by_name(cls, name): + return parse_enum_by_name(name=name, enum_type=ShardingStrategy) - class MMapIndexedDatasetConfig(BaseModel): - path: Path - skip_warmup: bool - class OpenGPTXMMapDatasetConfig(BaseModel): - num_samples: conint(ge=1) - path: FilePath - sample_key: str - sequence_len: PositiveInt +class CheckpointingConfig(BaseModel): + checkpointing_strategy: PydanticCheckpointingStrategyIFType + checkpointing_execution: PydanticCheckpointingExecutionIFType - type_hint: DatasetTypes - config: Union[ - MemMapDatasetConfig, - OpenGPTXMMapDatasetConfig, - PackedMemMapDatasetContinuousConfig, - PackedMemMapDatasetMegatronConfig, - MMapIndexedDatasetConfig, - ] = Field(union_mode="left_to_right") +class AdamWOptimizerConfig(BaseModel): + lr: float + wrapped_model: PydanticModelIFType -class SamplerConfig(BaseModel): - class DistributedSamplerConfig(BaseModel): - rank: conint(ge=0) - num_replicas: conint(ge=0) - shuffle: bool - type_hint: SamplerTypes - config: DistributedSamplerConfig +class CheckpointedOptimizerConfig(BaseModel): + checkpointing: PydanticCheckpointingIFType + checkpoint_path: Path + wrapped_model: PydanticModelIFType + optimizer: PydanticOptimizerIFType -class BatchSamplerConfig(BaseModel): - class StandardBatchSamplerConfig(BaseModel): - sampler: SamplerConfig - batch_size: conint(gt=0) - drop_last: bool +class CheckpointedModelConfig(BaseModel): + checkpointing: PydanticCheckpointingIFType + checkpoint_path: Path + model: PydanticModelIFType - type_hint: BatchSamplerTypes - config: StandardBatchSamplerConfig +class FSDPWrappedModelConfig(BaseModel): + model: PydanticModelIFType + sync_module_states: bool + mixed_precision_settings: MixedPrecisionSettings + sharding_strategy: ShardingStrategy + block_names: List[str] -class CollatorConfig(BaseModel): - class GPT2LLMCollatorConfig(BaseModel): - sample_key: str - target_key: str + @field_validator("mixed_precision_settings", mode="before") + def parse_mixed_precision_setting_by_name(cls, name): + mixed_precision_settings: MixedPrecisionSettings = parse_enum_by_name( + name=name, enum_type=MixedPrecisionSettings + ) + if not has_bfloat_support() and ( + mixed_precision_settings == MixedPrecisionSettings.BF_16 + or mixed_precision_settings == MixedPrecisionSettings.BF_16_WORKING + ): + raise ValueError("BF16 not supported in the current environment") + return mixed_precision_settings - type_hint: CollatorTypes - config: GPT2LLMCollatorConfig + @field_validator("sharding_strategy", mode="before") + def parse_sharding_strategy_by_name(cls, name): + return parse_enum_by_name(name=name, enum_type=ShardingStrategy) -class DataLoaderConfig(BaseModel): - class LLMDataLoaderConfig(CudaKwargsConfig): - dataloader_tag: str - dataset: DatasetConfig - batch_sampler: BatchSamplerConfig - collate_fn: CollatorConfig +class GPT2TokenizerFastConfig(BaseModel): + # Note: huggingface tokenizers expect 
file path as string + tokenizer_file: str - type_hint: DataloaderTypes - config: LLMDataLoaderConfig + +class DistributedSamplerConfig(BaseModel): + rank: Annotated[int, Field(strict=True, ge=0)] + num_replicas: Annotated[int, Field(strict=True, ge=0)] + shuffle: bool + dataset: PydanticDatasetIFType -class DataConfig(BaseModel): +class MemMapDatasetConfig(BaseModel): + raw_data_path: FilePath + index_path: Optional[FilePath] = None + block_size: Annotated[int, Field(strict=True, gt=0)] + tokenizer: PydanticTokenizerIFType + jq_pattern: str sample_key: str - target_key: str - sequence_len: int - train_dataloader: DataLoaderConfig - eval_dataloaders: List[DataLoaderConfig] -class ModelConfig(BaseModel): - type_hint: ModelTypes - config: GPT2Config +class PackedMemMapDatasetContinuousConfig(BaseModel): + raw_data_path: Path + block_size: Annotated[int, Field(strict=True, gt=0)] + sample_key: str -class CLMCrossEntropyLossConfig(BaseModel): +class PackedMemMapDatasetMegatronConfig(BaseModel): + raw_data_path: Path + block_size: Annotated[int, Field(strict=True, gt=0)] + sample_key: str + + +class MMapIndexedDatasetConfig(BaseModel): + path: Path + skip_warmup: bool + + +class OpenGPTXMMapDatasetConfig(BaseModel): + num_samples: Annotated[int, Field(strict=True, ge=1)] + path: FilePath + sample_key: str + sequence_len: PositiveInt + + +class BatchSamplerConfig(BaseModel): + sampler: PydanticSamplerIFType + batch_size: Annotated[int, Field(strict=True, gt=0)] + drop_last: bool + + +class ResumableBatchSamplerConfig(BaseModel): + sampler: PydanticSamplerIFType + start_index: Annotated[int, Field(strict=True, gt=0)] + + +class GPT2LLMCollateFnConfig(BaseModel): + sample_key: str target_key: str - prediction_key: str -class LossConfig(BaseModel): - type_hint: LossTypes - config: CLMCrossEntropyLossConfig - - -class TrainingConfig(BaseModel): - # TODO: use this in Progress Logging - global_num_training_samples: conint(gt=0) - callback_interval_in_samples: conint(gt=0) - process_group_backend: ProcessGroupBackendType - local_rank: conint(ge=0) - global_rank: conint(ge=0) - world_size: conint(ge=0) - main_rank: conint(ge=0) - local_train_micro_batch_size: conint(gt=0) - global_num_seen_samples: conint(ge=0) - do_apply_activation_checkpointing: bool - gradient_acc_step: conint(gt=0) - - @property - def local_train_batch_size(self): - return self.local_train_micro_batch_size * self.gradient_acc_step - - @property - def global_train_batch_size(self): - return self.local_train_batch_size * self.world_size - - @property - def local_num_train_samples(self): - exact = self.global_num_training_samples / self.world_size - ret = self.global_num_training_samples // self.world_size - if exact != ret: - print(f"Calculated local_num_training_samples is not an integer. Clipping {exact} to {ret} ") - return ret - - @property - def local_num_seen_train_samples(self): - exact = self.global_num_seen_samples / self.world_size - ret = self.global_num_seen_samples // self.world_size - if exact != ret: - print(f"Calculated global_num_seen_samples is not an integer. Clipping {exact} to {ret} ") - return ret - - @property - def skip_num_local_train_batches(self) -> int: - exact = self.global_num_seen_samples / self.world_size / self.local_train_micro_batch_size - ret = self.global_num_seen_samples // self.world_size // self.local_train_micro_batch_size - if exact != ret: - print(f"Calculated skip_num_local_train_batches is not an integer. 
Clipping {exact} to {ret} ") - return ret - - @property - def num_training_batches(self) -> int: - exact = self.global_num_training_samples / self.local_train_micro_batch_size - ret = self.global_num_training_samples // self.local_train_micro_batch_size - if exact != ret: - warnings.warn(f"Calculated num_training_batches is not an integer. Clipping {exact} to {ret} ") - return ret - - @property - def callback_interval_in_batches_per_rank(self): - exact = self.callback_interval_in_samples / self.local_train_micro_batch_size / self.world_size - ret = max(self.callback_interval_in_samples // self.local_train_micro_batch_size // self.world_size, 1) - if exact != ret: - warnings.warn( - f"Calculated callback_interval_in_batches_per_rank is not an integer. Clipping {exact} to {ret} " - ) - return ret - - -class AdamWConfig(BaseModel): - lr: confloat(ge=0.0) - - -class OptimizerConfig(BaseModel): - type_hint: OptimizerTypes - config: AdamWConfig - - -class OneCycleLRConfig(BaseModel): - max_lr: PositiveFloat - total_steps: conint(ge=1) - pct_start: confloat(ge=0.0) - anneal_strategy: str - cycle_momentum: bool - base_momentum: float | List - max_momentum: float | List - div_factor: PositiveFloat - final_div_factor: PositiveFloat - three_phase: bool - last_epochs: int - verbose: bool - - -class StepLRConfig(BaseModel): - step_size: conint(ge=1) - gamma: confloat(ge=0.0) - - -class ConstantLRConfig(BaseModel): - factor: PositiveFloat - total_iters: PositiveInt - - -class SchedulerConfig(BaseModel): - type_hint: SchedulerTypes - config: StepLRConfig | ConstantLRConfig | OneCycleLRConfig +class LLMDataLoaderConfig(BaseModel): + dataloader_tag: str + dataset: PydanticDatasetIFType + batch_sampler: PydanticSamplerIFType + collate_fn: PydanticCollateFnIFType + num_workers: Annotated[int, Field(strict=True, ge=0)] + pin_memory: bool + shuffle: bool + skip_num_batches: Optional[int] = 0 -class CheckpointingConfig(BaseModel): - class CheckpointingStrategyConfig(BaseModel): - class SaveEveryKStepsCheckpointingStrategyConfig(BaseModel): - k: PositiveInt - - class SaveKMostRecentCheckpointsStrategyConfig(BaseModel): - k: conint(ge=-1) - - type_hint: CheckpointingStrategyTypes - config: SaveEveryKStepsCheckpointingStrategyConfig | SaveKMostRecentCheckpointsStrategyConfig - - class CheckpointingExecutionConfig(BaseModel): - class FSDPToDiscCheckpointingConfig(BaseModel): - checkpoint_path: Path - global_rank: conint(ge=0) - - type_hint: CheckpointingExectionTypes - config: FSDPToDiscCheckpointingConfig - - checkpointing_strategy: CheckpointingStrategyConfig - checkpointing_execution: CheckpointingExecutionConfig - - -class RunMode(Enum): - FROM_SCRATCH = "FROM_SCRATCH" - WARM_START = "WARM_START" - -class ModalitiesSetupConfig(BaseModel): - class WarmStartSettings(BaseModel): - checkpoint_model_path: Path - global_num_seen_samples: conint(gt=0) - checkpoint_optimizer_path: Optional[Path] = None - checkpoint_lr_scheduler_path: Optional[Path] = None - - class FromScratchSettings(BaseModel): - global_num_seen_samples: int = 0 - - run_mode: RunMode - settings: FromScratchSettings - # settings: WarmStartSettings - - @model_validator(mode="after") - def check_passwords_match(self) -> "ModalitiesSetupConfig": - if self.run_mode == RunMode.FROM_SCRATCH: - if self.settings.global_num_seen_samples != 0: - raise ValueError("When starting from scratch, global_num_seen_samples must be 0.") - return self - - -class AppConfig(BaseModel): - modalities_setup: ModalitiesSetupConfig - data: DataConfig - training: TrainingConfig - 
running_env: RunningEnvConfig - model: ModelConfig - optimizer: OptimizerConfig - scheduler: SchedulerConfig - checkpointing: CheckpointingConfig - wandb: WandbConfig - loss: LossConfig - - -class PretrainedGPTConfig(PretrainedConfig): - model_type = "modalities_gpt2" - - def __init__(self, config: GPT2Config = None, **kwargs): - if type(config) == dict: - config = GPT2Config(**config) - self.config = config - - super().__init__(**kwargs) - - def to_json_string(self, use_diff: bool = True) -> str: - if self.config: - json_dict = {"config": self.config.__dict__.copy(), "model_type": self.model_type} - json_dict["config"]["attention"] = { - "attention_type": self.config.attention.attention_type.value, - "scaling_factor": self.config.attention.scaling_factor, - } - json_dict["config"]["weight_init"] = { - "mean": self.config.weight_init.mean, - "std": self.config.weight_init.std, - } - else: - json_dict = {} - return json.dumps(json_dict) +class DummyProgressSubscriberConfig(BaseModel): + pass + + +class RichProgressSubscriberConfig(BaseModel): + train_dataloader: PydanticLLMDataLoaderIFType + eval_dataloaders: Optional[List[PydanticLLMDataLoaderIFType]] = Field(default_factory=list) + world_size: int + global_num_seen_samples: int + local_rank: int + + +class DummyResultSubscriberConfig(BaseModel): + pass + + +class WandBEvaluationResultSubscriberConfig(BaseModel): + local_rank: int + project: str + experiment_id: str + mode: WandbMode + directory: Path + experiment_config: Optional[Dict] = None + + +class RichResultSubscriberConfig(BaseModel): + num_ranks: int + local_rank: int + + +class CudaEnv(BaseModel): + local_rank: Annotated[int, Field(strict=True, ge=0)] + world_size: Annotated[int, Field(strict=True, ge=1)] + global_rank: Annotated[int, Field(strict=True, ge=0)] + + +class Settings(BaseModel): + class Training(BaseModel): + callback_interval_in_samples: Annotated[int, Field(strict=True, ge=1)] + global_num_training_samples: Annotated[int, Field(strict=True, ge=1)] + global_num_seen_samples: Annotated[int, Field(strict=True, ge=0)] + do_apply_activation_checkpointing: bool + gradient_acc_steps: Annotated[int, Field(strict=True, ge=1)] + local_train_micro_batch_size: Annotated[int, Field(strict=True, ge=1)] + sequence_length: Annotated[int, Field(strict=True, ge=1)] + + class Paths(BaseModel): + checkpointing_path: Path + + experiment_id: str + referencing_keys: Dict[str, str] + training: Training + cuda_env: CudaEnv + paths: Paths + + +class ComponentsModel(BaseModel): + wrapped_model: PydanticModelIFType + optimizer: PydanticOptimizerIFType + loss_fn: PydanticLossIFType + train_dataloader: PydanticLLMDataLoaderIFType + eval_dataloaders: List[PydanticLLMDataLoaderIFType] + batch_progress_subscriber: PydanticMessageSubscriberIFType + evaluation_subscriber: PydanticMessageSubscriberIFType + checkpointing: PydanticCheckpointingIFType + settings: Settings + + +class ComponentsInferenceModel(BaseModel): + wrapped_model: PydanticModelIFType + cuda_env: CudaEnv + + +def load_app_config_dict(config_file_path: Path) -> Dict: + def cuda_env_resolver_fun(var_name: str) -> int: + int_env_variable_names = ["LOCAL_RANK", "WORLD_SIZE", "RANK"] + return int(os.getenv(var_name)) if var_name in int_env_variable_names else os.getenv(var_name) + + def modalities_env_resolver_fun(var_name: str) -> int: + if var_name == "experiment_id": + return get_date_of_run() + + OmegaConf.register_new_resolver("cuda_env", cuda_env_resolver_fun, replace=True) + OmegaConf.register_new_resolver("modalities_env", 
modalities_env_resolver_fun, replace=True) + + cfg = OmegaConf.load(config_file_path) + config_dict = OmegaConf.to_container(cfg, resolve=True) + return config_dict diff --git a/src/modalities/config/lookup_enum.py b/src/modalities/config/lookup_enum.py new file mode 100644 index 00000000..1e033735 --- /dev/null +++ b/src/modalities/config/lookup_enum.py @@ -0,0 +1,8 @@ +from enum import Enum + + +class LookupEnum(Enum): + @classmethod + def _missing_(cls, value: str) -> type: + """constructs Enum by member name, if not constructable by value""" + return cls.__dict__[value] diff --git a/src/modalities/config/lookup_types.py b/src/modalities/config/lookup_types.py deleted file mode 100644 index 46147480..00000000 --- a/src/modalities/config/lookup_types.py +++ /dev/null @@ -1,83 +0,0 @@ -from enum import Enum - -import torch -from torch.utils.data import BatchSampler, DistributedSampler -from transformers import GPT2TokenizerFast - -from modalities.checkpointing.checkpointing_execution import FSDPToDiscCheckpointing -from modalities.checkpointing.checkpointing_strategies import ( - SaveEveryKStepsCheckpointingStrategy, - SaveKMostRecentCheckpointsStrategy, -) -from modalities.dataloader.dataloader import LLMDataLoader, RepeatingDataLoader -from modalities.dataloader.dataset import MemMapDataset, PackedMemMapDatasetContinuous, PackedMemMapDatasetMegatron -from modalities.dataloader.open_gptx_dataset.mmap_dataset import MMapIndexedDatasetBuilder -from modalities.dataloader.open_gptx_dataset.open_gptx_dataset import OpenGPTXMMapDataset -from modalities.loss_functions import CLMCrossEntropyLoss -from modalities.models.gpt2.collator import GPT2LLMCollator -from modalities.models.gpt2.gpt2_model import GPT2LLM - - -class LookupEnum(Enum): - @classmethod - def _missing_(cls, value: str) -> type: - """constructs Enum by member name, if not constructable by value""" - return cls.__dict__[value] - - -class ModelTypes(LookupEnum): - GPT2LLM = GPT2LLM - - -class LossTypes(LookupEnum): - CLMCrossEntropyLoss = CLMCrossEntropyLoss - - -class OptimizerTypes(LookupEnum): - AdamW = torch.optim.AdamW - - -class SchedulerTypes(LookupEnum): - StepLR = torch.optim.lr_scheduler.StepLR - ConstantLR = torch.optim.lr_scheduler.ConstantLR - OneCycleLR = torch.optim.lr_scheduler.OneCycleLR - - -class TokenizerTypes(LookupEnum): - GPT2TokenizerFast = GPT2TokenizerFast - - -class DatasetTypes(LookupEnum): - MemMapDataset = MemMapDataset - PackedMemMapDatasetContinuous = PackedMemMapDatasetContinuous - PackedMemMapDatasetMegatron = PackedMemMapDatasetMegatron - MMapIndexedDataset = MMapIndexedDatasetBuilder - # TODO: ClassResolver does not work with functions ... therefore there is also no - # support for factories. 
- OpenGPTXMMapDataset = OpenGPTXMMapDataset # member(OpenGPTXDatasetFactory.create_dataset) - - -class SamplerTypes(LookupEnum): - DistributedSampler = DistributedSampler - - -class BatchSamplerTypes(LookupEnum): - BatchSampler = BatchSampler - - -class CollatorTypes(LookupEnum): - GPT2LLMCollator = GPT2LLMCollator - - -class DataloaderTypes(LookupEnum): - RepeatingDataLoader = RepeatingDataLoader - LLMDataLoader = LLMDataLoader - - -class CheckpointingStrategyTypes(LookupEnum): - SaveKMostRecentCheckpointsStrategy = SaveKMostRecentCheckpointsStrategy - SaveEveryKStepsCheckpointingStrategy = SaveEveryKStepsCheckpointingStrategy - - -class CheckpointingExectionTypes(LookupEnum): - FSDPToDiscCheckpointing = FSDPToDiscCheckpointing diff --git a/src/modalities/config/types.py b/src/modalities/config/types.py deleted file mode 100644 index 803abc45..00000000 --- a/src/modalities/config/types.py +++ /dev/null @@ -1,5 +0,0 @@ -from enum import Enum - - -class ProcessGroupBackendType(Enum): - nccl = "nccl" diff --git a/src/modalities/dataloader/create_index.py b/src/modalities/dataloader/create_index.py index 8b5e0e3c..1fc0d4d9 100644 --- a/src/modalities/dataloader/create_index.py +++ b/src/modalities/dataloader/create_index.py @@ -6,16 +6,14 @@ import warnings from pathlib import Path -import numpy as np from tqdm import tqdm -# TODO: benchmark against pyspark class IndexGenerator: def __init__(self, src_file: Path, chunksize: int = 4096, drop_faulty_entries: bool = False): """ Reads in a JSON file as a binary file, iterates character by character und builds up - the sample index (char-wisestart and end position for each JSON sample) via "\n" character positions. + the sample index (char-wise start and end position for each JSON sample) via "\n" character positions. :param src_file: Path to a jsonl-file. :param chunksize: defines the size of byte chunks that are processed via a producer-consumer approach. 
@@ -26,12 +24,11 @@ def __init__(self, src_file: Path, chunksize: int = 4096, drop_faulty_entries: b self.src_file = src_file self.chunksize = chunksize self.drop_faulty_entries = drop_faulty_entries - with self.src_file.open(mode="r", encoding="utf-8") as fin: + with self.src_file.open(mode="r") as fin: fin.seek(0, os.SEEK_END) - num_chars = fin.tell() - self.num_chunks = num_chars // self.chunksize - self.reminder = num_chars % self.chunksize - self._chunk_queue = queue.Queue() + self._total_num_chars = fin.tell() + self.num_chunks = self._total_num_chars // self.chunksize + self._queue_of_raw_lines = queue.Queue() self._index_map = [] self._exception_buffer = [] @@ -51,49 +48,42 @@ def create_index(self, target_path_for_index_file: Path): def _indexer_thread(self): def queue_generator(): while True: - chunk = self._chunk_queue.get() - if chunk is None: + line = self._queue_of_raw_lines.get() + if line is None: break - yield chunk + yield line - def process_line(last_index: int, curr_index: int): - segment_len = curr_index - last_index + def parse_line_as_json(line_start_idx: int, line: str): try: # check if line is a valid json - line = np.memmap(self.src_file, mode="r", offset=last_index, shape=(segment_len,)).view("S1").tolist() - line = [c.decode("utf8") for c in line] - line = "".join(line) json.loads(line) - self._index_map.append((last_index, segment_len)) + self._index_map.append((line_start_idx, len(line))) except Exception as low_level_err: if self.drop_faulty_entries: - warnings.warn(f"faulty line at {last_index}-{curr_index}, skipping...") + warnings.warn(f'faulty line "{line}", skipping...') else: - warnings.warn(f"faulty line: {line=}") - err = ValueError(f"faulty line at {last_index}-{curr_index}") + err = ValueError(f'faulty line "{line}", skipping...') err.__cause__ = low_level_err self._exception_buffer.append(err) self._index_map = [] - last_index = 0 - for chunk_idx, chunk in tqdm(enumerate(queue_generator()), desc="Processed Chunks", total=self.num_chunks): - for char_index, c in enumerate(chunk): - curr_index = chunk_idx * self.chunksize + char_index - if c == ord("\n"): - process_line(last_index, curr_index) - last_index = curr_index + 1 - # prevents automatically added "\n"-chars at the end of files getting interpreted as own sample - if curr_index >= last_index: - process_line(last_index, curr_index + 1) + for line_start_idx, line in tqdm(queue_generator(), desc="Processed Lines"): + if self._check_for_parallel_errors(): + return + parse_line_as_json(line_start_idx, line) def _reader_thread(self): - with open(self.src_file, "rb") as fin: + with open(self.src_file, "r") as fin: while True: - chunk = fin.read(self.chunksize) - if self._exception_buffer: - raise RuntimeError( - "Exception found in exception buffer. Probably the indexer thread ran into an error..." 
- ) - if not chunk: + cursor = fin.tell() + line = fin.readline() + if self._check_for_parallel_errors(): + return + if fin.tell() == self._total_num_chars: + self._queue_of_raw_lines.put((cursor, line)) break - self._chunk_queue.put(chunk) - self._chunk_queue.put(None) + line_without_newline_char = line[:-1] + self._queue_of_raw_lines.put((cursor, line_without_newline_char)) + self._queue_of_raw_lines.put(None) + + def _check_for_parallel_errors(self) -> bool: + return bool(self._exception_buffer) diff --git a/src/modalities/dataloader/create_packed_data.py b/src/modalities/dataloader/create_packed_data.py index 6e8d4d3c..f2ba6419 100644 --- a/src/modalities/dataloader/create_packed_data.py +++ b/src/modalities/dataloader/create_packed_data.py @@ -1,7 +1,11 @@ +import logging +import math +import multiprocessing +import os import pickle import warnings from pathlib import Path -from typing import IO +from typing import Callable, Iterator, List, Tuple import jq import numpy as np @@ -10,23 +14,21 @@ from modalities.dataloader.large_file_lines_reader import LargeFileLinesReader +logger = logging.getLogger(__name__) + + +class EmptySampleError(RuntimeError): + pass -class PackedDataGenerator: - # amount of bytes to represent tokens as integers. - # If the vocabulary exceeds 2^(8*`size_in_bytes`), this requires adaptation. - TOKEN_SIZE_IN_BYTES = 4 - # amount of bytes to represent number of all tokens in dataset. - # If the amount exceeds 2^(8*`header_size_in_bytes`), this requires adaptation. - # Decided to keep this constant, since a size of 8 bytes requires more data than the internet currently provides - HEAD_SIZE_IN_BYTES = 8 +class PackedDataGenerator: def __init__( self, src_path: Path, tokenizer: PreTrainedTokenizer, index_path: Path = None, jq_pattern: str = ".text", - max_number_of_tokens: int = None, + number_of_processes: int = os.cpu_count(), ): """ Reads in a jsonl file and the corresponding index file and packs dataset file for LLM training. @@ -38,18 +40,25 @@ def __init__( :param tokenizer: PretrainedTokenizer object, which is used to pre-tokenize the provided data in `src_path`. Tokenization is necessary to work on final lengths of token sequences. :param jq_pattern: jq-pattern applied on every jsonl-entry. Results are afterwards tokenized and packed - :param max_number_of_tokens: Limit the total amount of tokens in the packed dataset. - If not specified, the whole data is packed into the dataset. 
""" self.src_path = src_path self.tokenizer = tokenizer + self._token_size_in_bytes = self._get_required_num_of_bytes_to_repr(self.tokenizer.vocab_size) + encoded_eos_token = self.tokenizer(self.tokenizer.eos_token)["input_ids"][0] + self._encoded_eos_token_as_bytes = self._encoded_token_to_bytes(encoded_eos_token) self.jq_filter = jq.compile(jq_pattern) - self.max_tokens = max_number_of_tokens - + self._number_of_processes = number_of_processes self._reader = LargeFileLinesReader(src_path, index_path=index_path) self._total_num_of_tokens = 0 - self._curr_offset = self.HEAD_SIZE_IN_BYTES - self._index_list = [] + self._tokens_write_queue = multiprocessing.Queue() + self._exception_buffer = [] + + @staticmethod + def _get_required_num_of_bytes_to_repr(int_to_get_repr: int) -> int: + return math.ceil(math.log(math.log2(int_to_get_repr), 8)) + + def _encoded_token_to_bytes(self, encoded_token: int) -> bytes: + return encoded_token.to_bytes(self._token_size_in_bytes, byteorder="big", signed=False) def _default_destination_path(self, destination_path: Path = None) -> Path: if destination_path is None: @@ -68,54 +77,177 @@ def run(self, dst_path: Path = None): if dst_path.exists(): raise ValueError(f"file already exists at destination path '{dst_path}'.") - encoded_eos_token = self.tokenizer(self.tokenizer.eos_token)["input_ids"][0] - encoded_eos_token_as_bytes = encoded_eos_token.to_bytes(self.TOKEN_SIZE_IN_BYTES, byteorder="big") - with dst_path.open("wb") as f: - # allocate first self.header_size_in_bytes bytes for header (encodes length of data section) - # not possible to prepend header after determining size of data section - f.write((0).to_bytes(self.HEAD_SIZE_IN_BYTES, byteorder="big")) - - # write data section (tokens) - for idx, line in tqdm(enumerate(self._reader)): - try: - self._process_line(encoded_eos_token_as_bytes, f, line) - except ValueError: - warnings.warn(f"Encountered empty sample in line {idx} of file {self.src_path}") - except StopIteration: - break - except Exception as exception: - warnings.warn(f"could not process line: {exception=}") - - # write index - f.write(pickle.dumps(self._index_list)) - - self._update_data_length_in_pre_allocated_header(dst_path) - - def _update_data_length_in_pre_allocated_header(self, dst_path: Path): - start_of_index_in_bytes = self._index_list[-1][0] + self._index_list[-1][1] - length_of_byte_encoded_data_section = start_of_index_in_bytes - self.HEAD_SIZE_IN_BYTES - header_content = length_of_byte_encoded_data_section.to_bytes(self.HEAD_SIZE_IN_BYTES, byteorder="big") - header_content = np.frombuffer(header_content, dtype="uint8") - # write the header content to the packed dataset file - m = np.memmap(dst_path, mode="r+", offset=0, shape=(self.HEAD_SIZE_IN_BYTES,)) - m[:] = header_content[:] - - def _process_line(self, eos_token_as_bytes: bytes, f: IO, line: str): + self._exception_buffer = [] + try: + # not setting this can cause deadlocks when using hf's "FastTokenizers". 
See also: + # https://stackoverflow.com/questions/62691279/how-to-disable-tokenizers-parallelism-true-false-warning/67254879#67254879 + os.environ["TOKENIZERS_PARALLELISM"] = "false" + self._launch_parallelized_workers(dst_path) + finally: + os.unsetenv("TOKENIZERS_PARALLELISM") + + if self._exception_buffer: + raise self._exception_buffer[0] + + def _launch_parallelized_workers(self, dst_path: Path): + writer = multiprocessing.Process(target=self._writer_thread(dst_path)) + writer.start() + processor_threads = [ + multiprocessing.Process(target=self._process_thread, args=(i,)) for i in range(self._number_of_processes) + ] + for p in processor_threads: + p.start() + for p in processor_threads: + p.join() + self._stop_processing() + writer.join() + + def _stop_processing(self): + self._tokens_write_queue.put(None) + + def _generator_for_tokens_to_get_written(self): + while True: + if self._check_for_parallel_errors(): + return + tokens = self._tokens_write_queue.get() + if tokens is None: + break + yield tokens + + def _check_for_parallel_errors(self) -> bool: + return bool(self._exception_buffer) + + def _writer_thread(self, dst_path: Path) -> Callable: + def writer(): + index_list = [] + with dst_path.open("wb") as f: + # allocate first self.header_size_in_bytes bytes for header (encodes length of data section) + # not possible to prepend header after determining size of data section + f.write((0).to_bytes(EmbeddedStreamData.DATA_SECTION_LENGTH_IN_BYTES, byteorder="big")) + f.write( + self._token_size_in_bytes.to_bytes( + EmbeddedStreamData.TOKEN_SIZE_DESCRIPTOR_LENGTH_IN_BYTES, byteorder="big" + ) + ) + curr_offset = EmbeddedStreamData.HEADER_SIZE_IN_BYTES + + # write data section (tokens) + for tokens_as_bytes in tqdm( + self._generator_for_tokens_to_get_written(), desc="Processed Samples", total=len(self._reader) + ): + f.write(tokens_as_bytes) + segment_length = len(tokens_as_bytes) + index_list.append((curr_offset, segment_length)) + curr_offset += segment_length + + # write index + f.write(pickle.dumps(index_list)) + + self._update_data_length_in_pre_allocated_header(dst_path, index_list) + + return writer + + def _process_thread(self, process_id: int): + if self._check_for_parallel_errors(): + return + for idx in range(process_id, len(self._reader), self._number_of_processes): + line = self._reader[idx] + try: + self._tokens_write_queue.put(self._process_line(line)) + except EmptySampleError: + warnings.warn(f"Encountered empty sample in line {idx} of file {self.src_path}") + except Exception as exception: + warnings.warn(f"could not process line of number {idx}. 
Raised the following error: {exception=}")
+
+    def _update_data_length_in_pre_allocated_header(self, dst_path: Path, index_list: List[Tuple[int, int]]):
+        start_of_index_in_bytes = index_list[-1][0] + index_list[-1][1]
+        length_of_byte_encoded_data_section = start_of_index_in_bytes - EmbeddedStreamData.HEADER_SIZE_IN_BYTES
+        data_section_length_in_bytes = length_of_byte_encoded_data_section.to_bytes(
+            EmbeddedStreamData.DATA_SECTION_LENGTH_IN_BYTES, byteorder="big"
+        )
+        with dst_path.open("rb+") as fout:
+            fout.seek(0)
+            fout.write(data_section_length_in_bytes)
+
+    def _process_line(self, line: str) -> bytes:
        jq_retrieved_text = self.jq_filter.input_text(line).first()
+        if jq_retrieved_text is None:
+            raise ValueError(f"jq was not able to find anything using the expression: {self.jq_filter}")
        tokens = self.tokenizer(jq_retrieved_text)["input_ids"]
        if len(tokens) == 0:
-            raise ValueError("Received empty sample...")
-        token_idx = 0
-        for token in tokens:
-            token_as_bytes = token.to_bytes(self.TOKEN_SIZE_IN_BYTES, byteorder="big")
-            f.write(token_as_bytes)
-            self._total_num_of_tokens += 1
-            if self._total_num_of_tokens == self.max_tokens:
-                segment_length = (token_idx + 1) * self.TOKEN_SIZE_IN_BYTES
-                self._index_list.append((self._curr_offset, segment_length))
-                raise StopIteration
-            token_idx += 1
-        f.write(eos_token_as_bytes)
-        segment_length = (token_idx + 1) * self.TOKEN_SIZE_IN_BYTES  # segment_length in bytes
-        self._index_list.append((self._curr_offset, segment_length))
-        self._curr_offset += segment_length
+            raise EmptySampleError("Received empty sample...")
+        return b"".join(map(self._encoded_token_to_bytes, tokens)) + self._encoded_eos_token_as_bytes
+
+
+class EmbeddedStreamData:
+    # amount of bytes to represent number of all tokens in dataset.
+    # If the amount exceeds 2^(8*`header_size_in_bytes`), this requires adaptation.
+    # Decided to keep this constant, since a size of 8 bytes requires more data than the internet currently provides
+    DATA_SECTION_LENGTH_IN_BYTES = 8
+    TOKEN_SIZE_DESCRIPTOR_LENGTH_IN_BYTES = 4
+    HEADER_SIZE_IN_BYTES = DATA_SECTION_LENGTH_IN_BYTES + TOKEN_SIZE_DESCRIPTOR_LENGTH_IN_BYTES
+
+    def __init__(self, data_path: Path):
+        self._data_path = data_path
+        if not self._data_path.is_file():
+            raise FileNotFoundError(
+                f"Packed Data was not found at {self._data_path}."
+                f"Create one in advance by using `modalities data pack_encoded_data`."
+ ) + + with self._data_path.open("rb") as f: + # get number of bytes in data section + data_section_length_in_bytes = f.read(self.DATA_SECTION_LENGTH_IN_BYTES) + self.data_len = int.from_bytes(data_section_length_in_bytes, byteorder="big") + + # get number of bytes for encoding a single token + f.seek(self.DATA_SECTION_LENGTH_IN_BYTES) + token_size_as_bytes = f.read(self.TOKEN_SIZE_DESCRIPTOR_LENGTH_IN_BYTES) + self.token_size_in_bytes = int.from_bytes(token_size_as_bytes, byteorder="big", signed=False) + + # get index + f.seek(self.HEADER_SIZE_IN_BYTES + self.data_len) + pkl_encoded_index = f.read() + self.index_base = pickle.loads(pkl_encoded_index) + + # initialize memmapped data section + self.data = np.memmap(self._data_path, mode="r", offset=self.HEADER_SIZE_IN_BYTES, shape=(self.data_len,)) + + +def join_embedded_stream_data(stream_data: List[EmbeddedStreamData], target_file: Path, chunk_size: int = 2048): + if target_file.exists(): + raise FileExistsError(f'Target File at "{target_file}" exists!') + data_len = sum(d.data_len for d in stream_data) + assert len({d.token_size_in_bytes for d in stream_data}) == 1, ( + "Found different token representation sizes. This could indicate the usage of different tokenizers. " + "Not supported!" + ) + token_size_in_bytes = stream_data[0].token_size_in_bytes + + num_data_chunks = sum(math.ceil(d.data_len / chunk_size) for d in stream_data) + data_stream_generator = (d.data[i : i + chunk_size] for d in stream_data for i in range(0, d.data_len, chunk_size)) + + num_entries = sum(len(d.index_base) for d in stream_data) + + def index_stream_generator() -> Iterator[Tuple[int, int]]: + curr_offset = 0 + for embedded_stream_data in stream_data: + for entry_offset, segment_length in embedded_stream_data.index_base: + yield entry_offset + curr_offset, segment_length + curr_offset += embedded_stream_data.data_len + curr_offset -= embedded_stream_data.HEADER_SIZE_IN_BYTES + + with target_file.open("wb") as fout: + fout.write(data_len.to_bytes(EmbeddedStreamData.DATA_SECTION_LENGTH_IN_BYTES, byteorder="big")) + fout.write( + token_size_in_bytes.to_bytes(EmbeddedStreamData.TOKEN_SIZE_DESCRIPTOR_LENGTH_IN_BYTES, byteorder="big") + ) + for data_chunk in tqdm(data_stream_generator, total=num_data_chunks, desc="Writing Data Chunks..."): + fout.write(data_chunk) + + joint_index = [entry for entry in tqdm(index_stream_generator(), total=num_entries, desc="Concatenating Index")] + pickled_index = pickle.dumps(joint_index) + pickled_index_as_chunks = (pickled_index[i : i + chunk_size] for i in range(0, len(pickled_index), chunk_size)) + num_index_chunks = math.ceil(len(pickled_index) / chunk_size) + for index_chunk in tqdm(pickled_index_as_chunks, total=num_index_chunks, desc="Writing Index Chunks..."): + fout.write(index_chunk) diff --git a/src/modalities/dataloader/dataloader.py b/src/modalities/dataloader/dataloader.py index a3c6d3f4..695cc2a7 100644 --- a/src/modalities/dataloader/dataloader.py +++ b/src/modalities/dataloader/dataloader.py @@ -2,15 +2,14 @@ from torch.utils.data import Dataset, Sampler from torch.utils.data.dataloader import DataLoader, T_co, _collate_fn_t, _worker_init_fn_t - -from modalities.dataloader.samplers import ResumableBatchSampler +from torch.utils.data.sampler import BatchSampler class LLMDataLoader(DataLoader[T_co]): def __init__( self, dataloader_tag: str, - batch_sampler: ResumableBatchSampler, + batch_sampler: BatchSampler, dataset: Dataset[T_co], batch_size: Optional[int] = 1, shuffle: Optional[bool] = None, @@ -49,19 +48,24 @@ 
def __init__(
        )
        self._dataloader_tag = dataloader_tag
+        self._batch_size = batch_sampler.batch_size

    @property
    def dataloader_tag(self) -> str:
        return self._dataloader_tag

    @property
-    def sampler_batch_size(self) -> int:
+    def batch_size(self) -> int:
        # The parent Dataloader class has already a batch_size property defined which is originally used
        # when the batch_sampler is not specified. Since the LLMDataLoader enforces to always use a BatchSampler,
-        # we defined the property sampler_batch_size to return the actual batch size used in the dataloder.
+        # we define/override the property batch_size to return the actual batch size used in the dataloader.
        # BatchSampler is required, as we must seek forward in the dataloder during a warm start and
        # we don't want to load all the data during the fast-forward.
-        return self.batch_sampler.sampler_batch_size
+        return self._batch_size
+
+    @batch_size.setter
+    def batch_size(self, value: int):
+        self._batch_size = value

    @property
    def fast_forward_sample_id(self) -> int:
@@ -70,7 +74,7 @@ def fast_forward_sample_id(self) -> int:
        Returns:
            int: fast forward sample id
        """
-        return self.sampler_batch_size * self.batch_sampler.start_index
+        return self.batch_size * self.batch_sampler.start_index

    @property
    def fast_forward_batch_id(self) -> int:
@@ -83,15 +87,15 @@ def fast_forward_batch_id(self) -> int:


class RepeatingDataLoader(LLMDataLoader[T_co]):
-    def __init__(self, data_loader: LLMDataLoader[T_co], reshuffle_after_epoch: bool = False):
+    def __init__(self, dataloader: LLMDataLoader[T_co], reshuffle_after_epoch: bool = False):
        """Wraps an iterator to allow for infinite iteration. This is especially useful for
        DataLoader types that we wish to automatically restart upon completion.

        Args:
            loader (iterator): The data loader to repeat.
        """
-        self.data_loader = data_loader
-        self.data_iter = iter(self.data_loader)
+        self.dataloader = dataloader
+        self.data_iter = iter(self.dataloader)
        self.current_epoch = 0
        self.reshuffle_after_epoch = reshuffle_after_epoch

@@ -102,24 +106,24 @@ def __next__(self):
        try:
            batch = next(self.data_iter)
        except StopIteration:
-            if self.data_loader.sampler is not None:
+            if self.dataloader.sampler is not None:
                # In distributed mode, calling the set_epoch() method at the beginning of each epoch before creating
                # the DataLoader iterator is necessary to make shuffling work properly across multiple epochs.
                # Otherwise, the same ordering will be always used.
See discussion: # https://discuss.pytorch.org/t/why-is-sampler-set-epoch-epoch-needed-for-distributedsampler/149672 self.current_epoch += 1 - self.data_loader.sampler.set_epoch(self.current_epoch) - self.data_iter = iter(self.data_loader) + self.dataloader.sampler.set_epoch(self.current_epoch) + self.data_iter = iter(self.dataloader) batch = next(self.data_iter) return batch @property def dataloader_tag(self) -> str: - return self.data_loader._dataloader_tag + return self.dataloader._dataloader_tag @property - def sampler_batch_size(self) -> int: - return self.data_loader.batch_sampler.batch_size + def batch_size(self) -> int: + return self.dataloader.batch_sampler.batch_size @property def fast_forward_sample_id(self) -> int: @@ -128,7 +132,7 @@ def fast_forward_sample_id(self) -> int: Returns: int: fast forward sample id """ - return self.data_loader.sampler_batch_size * self.batch_sampler.start_index + return self.dataloader.batch_size * self.batch_sampler.start_index @property def fast_forward_batch_id(self) -> int: @@ -137,4 +141,4 @@ def fast_forward_batch_id(self) -> int: Returns: int: fast forward batch id """ - return self.data_loader.batch_sampler.start_index + return self.dataloader.batch_sampler.start_index diff --git a/src/modalities/dataloader/dataloader_factory.py b/src/modalities/dataloader/dataloader_factory.py index 225a4583..09606415 100644 --- a/src/modalities/dataloader/dataloader_factory.py +++ b/src/modalities/dataloader/dataloader_factory.py @@ -1,82 +1,34 @@ +from typing import Callable, Optional + +from torch.utils.data import BatchSampler from torch.utils.data.dataset import Dataset -from modalities.config.config import DataLoaderConfig, DatasetConfig from modalities.dataloader.dataloader import LLMDataLoader -from modalities.dataloader.open_gptx_dataset.open_gptx_dataset import OpenGPTXMMapDataset from modalities.dataloader.samplers import ResumableBatchSampler -from modalities.resolver_register import ResolverRegister - - -class OpenGPTXDatasetWrapper(Dataset): - def __init__(self, open_gptx_dataset: OpenGPTXMMapDataset, num_samples: int) -> None: - super().__init__() - self.open_gptx_dataset = open_gptx_dataset - self.num_samples = num_samples - - def __len__(self): - return self.num_samples - - def __getitem__(self, idx: int): - if self.num_samples > idx: - return self.open_gptx_dataset.__getitem__(idx) - else: - raise ValueError("num_samples <= idx") class DataloaderFactory: @staticmethod def get_dataloader( - resolvers: ResolverRegister, config: DataLoaderConfig, skip_num_batches: int = 0 + dataloader_tag: str, + dataset: Dataset, + batch_sampler: BatchSampler, + collate_fn: Callable, + num_workers: int, + pin_memory: bool, + shuffle: bool, + skip_num_batches: Optional[int] = 0, ) -> LLMDataLoader: - # TODO: replace this with dynamic nested object instantiation. (More details: Different Dataloaders require - # different objects in their constructors. the resolvers should be able to provide the necessary complex - # objects automatically, without us manually creating this complex factory.) 
- additional_init_payload = {} - if hasattr(config.config.dataset.config, "tokenizer"): - tokenizer = resolvers.build_component_by_config(config=config.config.dataset.config.tokenizer) - tokenizer.pad_token = tokenizer.eos_token - additional_init_payload.update(tokenizer=tokenizer) - - dataset = resolvers.build_component_by_config( - config=config.config.dataset, extra_kwargs=additional_init_payload - ) - - # BUG: Sometimes the dataset genereated by the OpenGPTXMMap implementation has too many samples. - # This is a workaround to fix the dataset to the size, as specified in the config! - # TODO: Fix the OpenGPTX implementation and get rid of this hack. - if isinstance(config.config.dataset.config, DatasetConfig.OpenGPTXMMapDatasetConfig): - dataset = OpenGPTXDatasetWrapper( - open_gptx_dataset=dataset, num_samples=config.config.dataset.config.num_samples - ) - - collator = resolvers.build_component_by_config(config=config.config.collate_fn) - sampler = resolvers.build_component_by_config( - config=config.config.batch_sampler.config.sampler, extra_kwargs=dict(dataset=dataset) - ) - - batch_sampler = resolvers.build_component_by_config( - config=config.config.batch_sampler, - extra_kwargs=dict( - sampler=sampler, - ), - ) - - resumable_batch_sampler = ResumableBatchSampler( - start_index=skip_num_batches, underlying_batch_sampler=batch_sampler - ) - - dataloader = resolvers.build_component_by_config( - config=config, - extra_kwargs=dict( - dataset=dataset, - batch_sampler=resumable_batch_sampler, - collate_fn=collator, - ), + batch_sampler = ResumableBatchSampler(start_index=skip_num_batches, underlying_batch_sampler=batch_sampler) + + dataloader = LLMDataLoader( + dataloader_tag=dataloader_tag, + batch_sampler=batch_sampler, + dataset=dataset, + collate_fn=collate_fn, + num_workers=num_workers, + pin_memory=pin_memory, + shuffle=shuffle, ) - # TODO we should have this check rather in the gym. Here, it is clear that - # we are using the LLMDataLoader - assert isinstance( - dataloader, LLMDataLoader - ), f"Dataloader Class must use the {LLMDataLoader.__name__}-Interface" return dataloader diff --git a/src/modalities/dataloader/dataset.py b/src/modalities/dataloader/dataset.py index 8e7a4c3b..ef0ae2ad 100644 --- a/src/modalities/dataloader/dataset.py +++ b/src/modalities/dataloader/dataset.py @@ -1,7 +1,5 @@ from __future__ import annotations -import os -import pickle from pathlib import Path from typing import List, Optional, Tuple @@ -12,6 +10,7 @@ from transformers import BatchEncoding, PreTrainedTokenizer from ..dataloader.large_file_lines_reader import LargeFileLinesReader +from .create_packed_data import EmbeddedStreamData class Dataset(TorchdataSet): @@ -70,98 +69,72 @@ def __getitem__(self, idx: int) -> BatchEncoding: class PackedMemMapDatasetBase(Dataset): - INT_SIZE_IN_BYTES = 4 - HEADER_SIZE_IN_BYTES = 8 + DATA_SECTION_LENGTH_IN_BYTES = EmbeddedStreamData.DATA_SECTION_LENGTH_IN_BYTES + TOKEN_SIZE_DESCRIPTOR_LENGTH_IN_BYTES = EmbeddedStreamData.TOKEN_SIZE_DESCRIPTOR_LENGTH_IN_BYTES + HEADER_SIZE_IN_BYTES = EmbeddedStreamData.HEADER_SIZE_IN_BYTES + np_dtype_from_num_bytes = { + 1: np.dtype(np.uint8).newbyteorder(">"), + 2: np.dtype(np.uint16).newbyteorder(">"), + 4: np.dtype(np.uint32).newbyteorder(">"), + 8: np.dtype(np.uint64).newbyteorder(">"), + } def __init__(self, raw_data_path: Path, block_size: int, sample_key: str): """ Base class for packed memmapped datasets. 
The underlying dataset file has the structure: | header | data | index | - The header contains information about the length of the subsequent data sequence. The index contains - the tuple information (start, end) in terms of byte positions. + The header contains information about the length of the subsequent data sequence and the amount of bytes + required to represent tokens in the data section. The index contains the tuple information (start, end) in terms + of byte positions. :param raw_data_path: Path to a packed binary file (*.pbin). - Use `modalities create_packed_data` to create one based on a jsonl-file. + Use `modalities data pack_encoded_data` to create one based on a jsonl-file. :param block_size: alias for max sequence length. The amount of tokens the model can handle. :param sample_key: model-specific parameter to indicate where in the BatchEncoding the input_token_ids are. TODO: If this setting should support multi-modal features using separately encoded inputs, this needs to get replaced with a list of sample keys! """ super().__init__(raw_data_path=raw_data_path, block_size=block_size, sample_key=sample_key) - if not self.raw_data_path.is_file(): - raise FileNotFoundError( - f"Packed Data was not found at {self.raw_data_path}." - f"Create on in advance by using `modalities create_packed_data`." + self._embedded_stream_data = EmbeddedStreamData(raw_data_path) + self._token_size_in_bytes = self._embedded_stream_data.token_size_in_bytes + try: + self._token_dtype = self.np_dtype_from_num_bytes[self._token_size_in_bytes] + except KeyError: + raise RuntimeError( + f"Encountered a required token representation with {self._token_size_in_bytes}," + " which is not supported. Consider using a smaller vocabulary." ) + self._index = self._generate_packing_index() - # get number of total bytes in file - with self.raw_data_path.open("rb") as f: - f.seek(0, os.SEEK_END) - self.total_bytes = f.tell() - f.seek(0) - - # get number of bytes in data section - self.data_len = np.memmap( - self.raw_data_path, - mode="r", - offset=0, - shape=(self.HEADER_SIZE_IN_BYTES,), - ).view(f"S{self.HEADER_SIZE_IN_BYTES}") - self.data_len = int.from_bytes(self.data_len, byteorder="big") - - # get index - self.index_base = np.memmap( - self.raw_data_path, - mode="r", - offset=self.HEADER_SIZE_IN_BYTES + self.data_len, - shape=(self.total_bytes - self.data_len - self.HEADER_SIZE_IN_BYTES,), - ).view(f"S{self.total_bytes-self.data_len-self.HEADER_SIZE_IN_BYTES}") - self.index_base = pickle.loads(self.index_base) - - -class PackedMemMapDatasetContinuous(PackedMemMapDatasetBase): - def __init__(self, raw_data_path: Path, block_size: int, sample_key: str): - """ - PackedMemMapDatasetContinuous iterates through the data in block_size sized chunks, - irrespective of the samples' start and end position, as defined in the index. - Therefore, for this datset, the index is irrelevant. - - :param raw_data_path: Path to a packed binary file (*.pbin). - Use `modalities create_packed_data` to create one based on a jsonl-file. - :param block_size: alias for max sequence length. The amount of tokens the model can handle. - :param sample_key: model-specific parameter to indicate where in the BatchEncoding the input_token_ids are. - TODO: If this setting should support multi-modal features using separately encoded inputs, - this needs to get replaced with a list of sample keys! 
- """ - super().__init__(raw_data_path=raw_data_path, block_size=block_size, sample_key=sample_key) - - # get number of total tokens in file - total_tokens = self.data_len // self.INT_SIZE_IN_BYTES - self._num_samples = total_tokens // self.block_size + def _generate_packing_index(self) -> List[Tuple[int, int]]: + raise NotImplementedError def __len__(self) -> int: - return self._num_samples + return len(self._index) def __getitem__(self, idx: int) -> BatchEncoding: self._check_if_inbounds(idx) - tokens_as_byte_strings = np.memmap( - self.raw_data_path, - mode="r", - offset=self.HEADER_SIZE_IN_BYTES + idx * self.INT_SIZE_IN_BYTES * self.block_size, - shape=(self.INT_SIZE_IN_BYTES * self.block_size,), - ).view(f"S{self.INT_SIZE_IN_BYTES}") - tokens = [int.from_bytes(token, byteorder="big") for token in tokens_as_byte_strings] + offset, length = self._index[idx] + tokens = np.frombuffer(self._embedded_stream_data.data, dtype=self._token_dtype, count=length, offset=offset) return BatchEncoding(data={self.sample_key: tokens}) +class PackedMemMapDatasetContinuous(PackedMemMapDatasetBase): + def _generate_packing_index(self) -> List[Tuple[int, int]]: + # get number of total tokens in file + total_tokens = self._embedded_stream_data.data_len // self._token_size_in_bytes + num_samples = total_tokens // self.block_size + return [(i * self.block_size * self._token_size_in_bytes, self.block_size) for i in range(num_samples)] + + class PackedMemMapDatasetMegatron(PackedMemMapDatasetBase): - def generate_megatron_index(self) -> List[Tuple[int, int]]: + def _generate_packing_index(self) -> List[Tuple[int, int]]: index = [] curr_offset = self.HEADER_SIZE_IN_BYTES curr_len = 0 - block_size_in_bytes = self.block_size * self.INT_SIZE_IN_BYTES - for segment_offset, segment_len in tqdm(self.index_base): - # When the sum of of the length of the current previously seen samples doesn't + block_size_in_bytes = self.block_size * self._token_size_in_bytes + for segment_offset, segment_len in tqdm(self._embedded_stream_data.index_base): + # When the sum of the length of the current previously seen samples doesn't # exceed block_size_in_bytes, we add the current segment length to the previous # ones and continue. if curr_len + segment_len < block_size_in_bytes: @@ -169,14 +142,14 @@ def generate_megatron_index(self) -> List[Tuple[int, int]]: # If the previous and current length equals block_size_in_bytes, we add the starting index # and the total sequences length to the index list as a new sample. elif curr_len + segment_len == block_size_in_bytes: - index.append((curr_offset, block_size_in_bytes)) + index.append((curr_offset, self.block_size)) curr_len = 0 curr_offset += block_size_in_bytes # Else case is executed when the current and previous segment length exceed the block_size. # In this case we set the starting point of the next sample to the end of the current sample. # This way, the start of a sample is never in the middle of a sentence. else: - index.append((curr_offset, block_size_in_bytes)) + index.append((curr_offset, self.block_size)) if segment_len > block_size_in_bytes: curr_offset += block_size_in_bytes curr_len = 0 @@ -184,30 +157,3 @@ def generate_megatron_index(self) -> List[Tuple[int, int]]: curr_offset = segment_offset curr_len = segment_len return index - - def __init__(self, raw_data_path: Path, block_size: int, sample_key: str): - """ - :param raw_data_path: Path to a packed binary file (*.pbin). - Use `modalities create_packed_data` to create one based on a jsonl-file. 
- :param block_size: alias for max sequence length. The amount of tokens the model can handle. - :param sample_key: model-specific parameter to indicate where in the BatchEncoding the input_token_ids are. - TODO: If this setting should support multi-modal features using separately encoded inputs, - this needs to get replaced with a list of sample keys! - """ - super().__init__(raw_data_path=raw_data_path, block_size=block_size, sample_key=sample_key) - self._index = self.generate_megatron_index() - - def __len__(self) -> int: - return len(self._index) - - def __getitem__(self, idx: int) -> BatchEncoding: - self._check_if_inbounds(idx) - offset, length = self._index[idx] - tokens_as_byte_strings = np.memmap( - self.raw_data_path, - mode="r", - offset=offset, - shape=(length,), - ).view(f"S{self.INT_SIZE_IN_BYTES}") - tokens = [int.from_bytes(token, byteorder="big") for token in tokens_as_byte_strings] - return BatchEncoding(data={self.sample_key: tokens}) diff --git a/src/modalities/dataloader/dataset_factory.py b/src/modalities/dataloader/dataset_factory.py new file mode 100644 index 00000000..157e98d0 --- /dev/null +++ b/src/modalities/dataloader/dataset_factory.py @@ -0,0 +1,85 @@ +from pathlib import Path +from typing import Optional + +from pydantic import FilePath +from torch.utils.data.dataset import Dataset +from transformers import PreTrainedTokenizer + +from modalities.dataloader.dataset import MemMapDataset, PackedMemMapDatasetContinuous, PackedMemMapDatasetMegatron +from modalities.dataloader.open_gptx_dataset.open_gptx_dataset import OpenGPTXMMapDataset + + +class OpenGPTXDatasetWrapper(Dataset): + def __init__(self, open_gptx_dataset: OpenGPTXMMapDataset, num_samples: int) -> None: + super().__init__() + self.open_gptx_dataset = open_gptx_dataset + self.num_samples = num_samples + + def __len__(self): + return self.num_samples + + def __getitem__(self, idx: int): + if self.num_samples > idx: + return self.open_gptx_dataset.__getitem__(idx) + else: + raise ValueError("num_samples <= idx") + + +class DatasetFactory: + @staticmethod + def get_mem_map_dataset( + raw_data_path: Path, + block_size: int, + tokenizer: PreTrainedTokenizer, + sample_key: str, + index_path: Optional[Path] = None, + jq_pattern: str = ".text", + ) -> MemMapDataset: + # TODO this was part of the old Dataloader implementation. + # we need to check if this is actually wanted generally. 
+ tokenizer.pad_token = tokenizer.eos_token + + dataset = MemMapDataset( + raw_data_path=raw_data_path, + block_size=block_size, + tokenizer=tokenizer, + sample_key=sample_key, + index_path=index_path, + jq_pattern=jq_pattern, + ) + return dataset + + @staticmethod + def get_packed_mem_map_dataset_continuous( + raw_data_path: Path, block_size: int, sample_key: str + ) -> PackedMemMapDatasetContinuous: + dataset = PackedMemMapDatasetContinuous( + raw_data_path=raw_data_path, block_size=block_size, sample_key=sample_key + ) + return dataset + + @staticmethod + def get_packed_mem_map_dataset_megatron( + raw_data_path: Path, block_size: int, sample_key: str + ) -> PackedMemMapDatasetMegatron: + dataset = PackedMemMapDatasetMegatron(raw_data_path=raw_data_path, block_size=block_size, sample_key=sample_key) + return dataset + + @staticmethod + def get_open_gptx_mmap_dataset( + sample_key: str, + path: FilePath, + sequence_len: int, + num_samples: int, + seed: int = 47, + ) -> OpenGPTXMMapDataset: + # part of open gptx + dataset = OpenGPTXMMapDataset( + sample_key=sample_key, path=path, sequence_len=sequence_len, num_samples=num_samples, seed=seed + ) + + # BUG: Sometimes the dataset genereated by the OpenGPTXMMap implementation has too many samples. + # This is a workaround to fix the dataset to the size, as specified in the config! + # TODO: Fix the OpenGPTX implementation and get rid of this hack. + dataset_wrapped = OpenGPTXDatasetWrapper(open_gptx_dataset=dataset, num_samples=num_samples) + return dataset_wrapped diff --git a/src/modalities/dataloader/large_file_lines_reader.py b/src/modalities/dataloader/large_file_lines_reader.py index d548c1f4..3d45072b 100644 --- a/src/modalities/dataloader/large_file_lines_reader.py +++ b/src/modalities/dataloader/large_file_lines_reader.py @@ -1,11 +1,8 @@ import pickle -import warnings from abc import ABC, abstractmethod from pathlib import Path from typing import List -import numpy as np - class BaseReader(ABC): @abstractmethod @@ -17,7 +14,6 @@ def __getitem__(self, key: int | slice) -> str | List[str]: raise NotImplementedError -# TODO: benchmark tokenized version vs plain text version (regarding speed and storage consumption) class LargeFileLinesReader(BaseReader): def __init__(self, raw_data_path: Path, index_path: Path = None): """ @@ -33,7 +29,7 @@ def __init__(self, raw_data_path: Path, index_path: Path = None): if not self.raw_data_path.is_file(): raise FileNotFoundError("Raw data file does not exist") if not self.index_path.is_file(): - raise FileNotFoundError("Index file does not exist. Use `modalities create_memmap_index` to create one.") + raise FileNotFoundError("Index file does not exist. Use `modalities data create_raw_index` to create one.") with self.index_path.open("rb") as f: self.index = pickle.load(f) @@ -56,21 +52,6 @@ def __getitem__(self, key: int | slice) -> str | List[str]: return self.__read_from_raw_file(offset, sample_length_in_bytes) def __read_from_raw_file(self, offset: int, sample_length_in_bytes: int) -> str: - def safe_decoder(byte_char): - try: - # TODO: verify why iso-8859-1 was necessary here in the path. 
- # Maybe there was an issue with the actual loading of the jsonl-files - c = byte_char.decode("utf8") - except Exception as exception: - c = "" - warnings.warn(f'Encountered invalid char: "{byte_char}".') - warnings.warn(f"Encountered problem: {exception}") - return c - - string = ( - np.memmap(self.raw_data_path, mode="r", offset=offset, shape=(sample_length_in_bytes,)).view("S1").tolist() - ) - decoded_string = [] - for c in string: - decoded_string.append(safe_decoder(c)) - return "".join(decoded_string) + f = self.raw_data_path.open() + f.seek(offset) + return f.read(sample_length_in_bytes) diff --git a/src/modalities/dataloader/open_gptx_dataset/open_gptx_dataset.py b/src/modalities/dataloader/open_gptx_dataset/open_gptx_dataset.py index 793649dc..43439377 100644 --- a/src/modalities/dataloader/open_gptx_dataset/open_gptx_dataset.py +++ b/src/modalities/dataloader/open_gptx_dataset/open_gptx_dataset.py @@ -417,23 +417,3 @@ def __getitem__(self, idx: int): # Sample is of length sequence_len + 1 because target toke is part of the sample return {self.sample_key: np.array(sample, dtype=np.int64)} - - -class OpenGPTXMMapDatasetFactory: - @staticmethod - def create_dataset(num_samples: int, path: FilePath, sample_key: str, sequence_len: int) -> OpenGPTXMMapDataset: - # dataset_dir = path.parents[0] - # dataset_filename_prefix = path.stem - # text_dataset = make_dataset(path=dataset_dir.joinpath(dataset_filename_prefix)) - - # instances = OpenGPTXDataset( - # sample_key=sample_key, - # text_dataset=text_dataset, - # doc_idx=np.arange(0, len(text_dataset)), - # dataset_dir=dataset_dir, - # num_samples=num_samples, - # dataset_name=dataset_filename_prefix, - # sequence_len=sequence_len, - # ) - # return instances - pass diff --git a/src/modalities/dataloader/samplers.py b/src/modalities/dataloader/samplers.py index af3a4aa2..c5ab2699 100644 --- a/src/modalities/dataloader/samplers.py +++ b/src/modalities/dataloader/samplers.py @@ -13,6 +13,8 @@ def __init__(self, start_index: int, underlying_batch_sampler: BatchSampler): self.start_index = start_index self.underlying_batch_sampler = underlying_batch_sampler + # NOTE: we are only iterating ove the indices not the actual data + # so this is relatively cheap self.indices = list(iter(self.underlying_batch_sampler)) def __iter__(self): @@ -22,5 +24,5 @@ def __len__(self): return len(self.indices) - self.start_index @property - def sampler_batch_size(self) -> int: + def batch_size(self) -> int: return self.underlying_batch_sampler.batch_size diff --git a/src/modalities/evaluator.py b/src/modalities/evaluator.py index 78adfe1e..a2823a2a 100644 --- a/src/modalities/evaluator.py +++ b/src/modalities/evaluator.py @@ -45,11 +45,11 @@ def evaluate( ) -> Dict[str, EvaluationResultBatch]: result_dict: Dict[str, EvaluationResultBatch] = {} model.eval() + + device = torch.device(self.local_rank if torch.cuda.is_available() else "cpu") + for data_loader in data_loaders: - if torch.cuda.is_available(): - cummulated_loss = torch.zeros(3).to(torch.device(self.local_rank)) - else: - cummulated_loss = torch.zeros(3).to("cpu") + cumulated_loss = torch.zeros(3).to(device) Evaluator._publish_progress( batch_progress_publisher=self.batch_progress_publisher, @@ -66,13 +66,13 @@ def evaluate( loss_fun=loss_fun, ) - cummulated_loss[0] += batch_loss.item() # sum up batch loss - cummulated_loss[1] += len(batch) - batch_length_tensor = torch.tensor(len(batch)).to(torch.device(self.local_rank)) + cumulated_loss[0] += batch_loss.item() # sum up batch loss + cumulated_loss[1] 
+= len(batch) + batch_length_tensor = torch.tensor(len(batch)).to(device) thoughput_aggregator.add_value(key=ThroughputAggregationKeys.NUM_SAMPLES, value=batch_length_tensor) local_dataset_sample_id = Evaluator._get_local_sample_id( - batch_id=batch_id, batch_size=data_loader.sampler_batch_size + batch_id=batch_id, batch_size=data_loader.batch_size ) global_dataset_sample_id = local_sample_id_to_global_sample_id(local_dataset_sample_id) @@ -85,22 +85,20 @@ def evaluate( ) # TODO: insert reducer from outside so Evaluator is independent of FSDP total_loss = Reducer.reduce( - tensor=cummulated_loss, + tensor=cumulated_loss, operation=dist.ReduceOp.SUM, post_processing_fun=lambda t: t[0] / t[1], ) - foward_backward_time = torch.tensor(forward_backward_timer_recorder.delta_t).to( - torch.device(self.local_rank) - ) + forward_backward_time = torch.tensor(forward_backward_timer_recorder.delta_t).to(device) thoughput_aggregator.add_value( - key=ThroughputAggregationKeys.FORWARD_BACKWARD_TIME, value=foward_backward_time + key=ThroughputAggregationKeys.FORWARD_BACKWARD_TIME, value=forward_backward_time ) synced_num_samples = thoughput_aggregator.get_all_reduced_value(ThroughputAggregationKeys.NUM_SAMPLES) - synced_foward_backward_time = thoughput_aggregator.get_all_reduced_value( + synced_forward_backward_time = thoughput_aggregator.get_all_reduced_value( ThroughputAggregationKeys.FORWARD_BACKWARD_TIME, reduce_operation=dist.ReduceOp.MAX ) - num_samples_per_second = synced_num_samples / synced_foward_backward_time + num_samples_per_second = synced_num_samples / synced_forward_backward_time evaluation_result = EvaluationResultBatch( losses={loss_fun.tag: total_loss}, diff --git a/src/modalities/exceptions.py b/src/modalities/exceptions.py index c5e5e3a2..07e344d5 100644 --- a/src/modalities/exceptions.py +++ b/src/modalities/exceptions.py @@ -15,4 +15,4 @@ class RunningEnvError(Exception): class TimeRecorderStateError(Exception): - pass \ No newline at end of file + pass diff --git a/src/modalities/logging_broker/message_broker.py b/src/modalities/logging_broker/message_broker.py index d5f4aec2..7b38e58f 100644 --- a/src/modalities/logging_broker/message_broker.py +++ b/src/modalities/logging_broker/message_broker.py @@ -1,12 +1,14 @@ from abc import ABC, abstractmethod from collections import defaultdict +from typing import Dict, List + from modalities.logging_broker.messages import Message, MessageTypes from modalities.logging_broker.subscriber import MessageSubscriberIF -from typing import Dict, List class MessageBrokerIF(ABC): """Interface for message broker objects.""" + @abstractmethod def add_subscriber(self, subscription: MessageTypes, subscriber: MessageSubscriberIF): raise NotImplementedError @@ -18,6 +20,7 @@ def distribute_message(self, message: Message): class MessageBroker(MessageBrokerIF): """The MessageBroker sends notifications to its subscribers.""" + def __init__(self) -> None: self.subscriptions: Dict[MessageTypes, List[MessageSubscriberIF]] = defaultdict(list) diff --git a/src/modalities/logging_broker/publisher.py b/src/modalities/logging_broker/publisher.py index 34ff834b..28cc27de 100644 --- a/src/modalities/logging_broker/publisher.py +++ b/src/modalities/logging_broker/publisher.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from typing import Generic, TypeVar -from modalities.logging_broker.message_broker import Message, MessageBroker +from modalities.logging_broker.message_broker import Message, MessageBroker from modalities.logging_broker.messages import 
MessageTypes T = TypeVar("T") @@ -15,6 +15,7 @@ def publish_message(self, payload: T, message_type: MessageTypes): class MessagePublisher(MessagePublisherIF[T]): """The MessagePublisher sends messages through a message broker.""" + def __init__( self, message_broker: MessageBroker, diff --git a/src/modalities/logging_broker/subscriber.py b/src/modalities/logging_broker/subscriber.py index 7e965b75..6b4e5c2d 100644 --- a/src/modalities/logging_broker/subscriber.py +++ b/src/modalities/logging_broker/subscriber.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod from typing import Generic, TypeVar + from modalities.logging_broker.messages import Message T = TypeVar("T") @@ -11,4 +12,3 @@ class MessageSubscriberIF(ABC, Generic[T]): @abstractmethod def consume_message(self, message: Message[T]): raise NotImplementedError - diff --git a/src/modalities/logging_broker/subscriber_impl/results_subscriber.py b/src/modalities/logging_broker/subscriber_impl/results_subscriber.py index b2965725..ff558a4d 100644 --- a/src/modalities/logging_broker/subscriber_impl/results_subscriber.py +++ b/src/modalities/logging_broker/subscriber_impl/results_subscriber.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Optional +from typing import Dict, Optional import rich from rich.console import Group @@ -7,9 +7,10 @@ import wandb from modalities.batch import EvaluationResultBatch +from modalities.config.config import WandbMode from modalities.logging_broker.messages import Message from modalities.logging_broker.subscriber import MessageSubscriberIF -from modalities.config.config import AppConfig, WandbConfig + class DummyResultSubscriber(MessageSubscriberIF[EvaluationResultBatch]): def consume_message(self, message: Message[EvaluationResultBatch]): @@ -49,16 +50,23 @@ def consume_message(self, message: Message[EvaluationResultBatch]): class WandBEvaluationResultSubscriber(MessageSubscriberIF[EvaluationResultBatch]): """A subscriber object for the WandBEvaluationResult observable.""" - def __init__(self, num_ranks: int, project: str, experiment_id: str, mode: WandbConfig.WandbMode, dir: Path, - experiment_config: Optional[AppConfig] = None) -> None: + def __init__( + self, + project: str, + experiment_id: str, + mode: WandbMode, + directory: Path, + experiment_config: Optional[Dict] = None, + ) -> None: super().__init__() - self.num_ranks = num_ranks # experiment_config_json = None # if experiment_config is not None: # experiment_config_json = experiment_config.model_dump(mode="json") - wandb.init(project=project, name=experiment_id, mode=mode.value.lower(), dir=dir, config=experiment_config) + wandb.init( + project=project, name=experiment_id, mode=mode.value.lower(), dir=directory, config=experiment_config + ) def consume_message(self, message: Message[EvaluationResultBatch]): """Consumes a message from a message broker.""" @@ -82,6 +90,4 @@ def consume_message(self, message: Message[EvaluationResultBatch]): f"{eval_result.dataloader_tag} {metric_key}": metric_values for metric_key, metric_values in eval_result.throughput_metrics.items() } - wandb.log( - data=throughput_metrics, step=eval_result.global_train_sample_id + 1 - ) + wandb.log(data=throughput_metrics, step=eval_result.global_train_sample_id + 1) diff --git a/src/modalities/logging_broker/subscriber_impl/subscriber_factory.py b/src/modalities/logging_broker/subscriber_impl/subscriber_factory.py new file mode 100644 index 00000000..3d63cdad --- /dev/null +++ b/src/modalities/logging_broker/subscriber_impl/subscriber_factory.py @@ -0,0 
+1,76 @@ +from pathlib import Path +from typing import Dict, List + +from modalities.config.config import WandbMode +from modalities.dataloader.dataloader import LLMDataLoader +from modalities.logging_broker.subscriber_impl.batch_progress_subscriber import ( + DummyProgressSubscriber, + RichProgressSubscriber, +) +from modalities.logging_broker.subscriber_impl.results_subscriber import ( + DummyResultSubscriber, + RichResultSubscriber, + WandBEvaluationResultSubscriber, +) + + +class ProgressSubscriberFactory: + @staticmethod + def get_rich_progress_subscriber( + train_dataloader: LLMDataLoader, + eval_dataloaders: List[LLMDataLoader], + world_size: int, + global_num_seen_samples: int, + local_rank: int, + ) -> RichProgressSubscriber: + if local_rank == 0: + skip_num_local_train_batches = global_num_seen_samples // world_size // train_dataloader.batch_size + train_split_num_samples = { + train_dataloader.dataloader_tag: (len(train_dataloader) + skip_num_local_train_batches) + * world_size + * train_dataloader.batch_size + } + + eval_splits_num_samples = { + dataloader.dataloader_tag: len(dataloader) * world_size * dataloader.batch_size + for dataloader in eval_dataloaders + } + + subscriber = RichProgressSubscriber(world_size, train_split_num_samples, eval_splits_num_samples) + else: + subscriber = ProgressSubscriberFactory.get_dummy_progress_subscriber() + return subscriber + + @staticmethod + def get_dummy_progress_subscriber() -> DummyProgressSubscriber: + return DummyProgressSubscriber() + + +class ResultsSubscriberFactory: + @staticmethod + def get_rich_result_subscriber(num_ranks: int, local_rank: int) -> RichResultSubscriber: + if local_rank == 0: + return RichResultSubscriber(num_ranks) + else: + return ResultsSubscriberFactory.get_dummy_result_subscriber() + + @staticmethod + def get_dummy_result_subscriber() -> DummyResultSubscriber: + return DummyResultSubscriber() + + @staticmethod + def get_wandb_result_subscriber( + local_rank: int, + project: str, + experiment_id: str, + mode: WandbMode, + directory: Path = None, + experiment_config: Dict = None, + ) -> WandBEvaluationResultSubscriber: + if local_rank == 0 and (mode == WandbMode.ONLINE or mode == WandbMode.OFFLINE): + result_subscriber = WandBEvaluationResultSubscriber( + project, experiment_id, mode, directory, experiment_config + ) + else: + result_subscriber = ResultsSubscriberFactory.get_dummy_result_subscriber() + return result_subscriber diff --git a/src/modalities/loss_functions.py b/src/modalities/loss_functions.py index c3144597..bf7b4251 100644 --- a/src/modalities/loss_functions.py +++ b/src/modalities/loss_functions.py @@ -41,3 +41,83 @@ def __call__(self, forward_batch: InferenceResultBatch) -> torch.Tensor: # Flatten the tokens loss = self.loss_fun(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) return loss + + +def nce_loss( + embedding1: torch.Tensor, embedding2: torch.Tensor, device: torch.device, is_asymmetric: bool, temperature: float +) -> torch.Tensor: + """ + This implementation calculates the noise contrastive estimation loss between embeddings of two different modalities + Implementation slightly adapted from https://arxiv.org/pdf/1912.06430.pdf, https://github.com/antoine77340/MIL-NCE_HowTo100M + changes include adding a temperature value and the choice of calculating asymmetric loss w.r.t. 
one modality + This implementation is adapted to contrastive loss from CoCa model https://arxiv.org/pdf/2205.01917.pdf + + Args: + embedding1 (torch.Tensor): embeddings from modality 1 of size batch_size x embed_dim. + embedding2 (torch.Tensor): embeddings from modality 2 of size batch_size x embed_dim. + device (torch.device): torch device for calculating loss. + is_asymmetric (bool): boolean value to specify if the loss is calculated in one direction or both directions. + temperature (float): temperature value for regulating loss. + + Returns: + torch.Tensor: loss tensor. + """ + # calculating the similarity matrix of size (batch_size x batch_size) + sim_matrix = torch.matmul(embedding1, embedding2.t()) / temperature + # numerator of loss: using similarity scores for all positive pairs (e.g., image and its caption) + numerator = sim_matrix * torch.eye(sim_matrix.shape[0], device=device) + numerator = numerator.sum(dim=0).view(sim_matrix.shape[0], -1) + numerator = torch.logsumexp(numerator, dim=1) + if is_asymmetric: + # denominator of loss: using all similarity scores for all pairs (positive and negative) + denominator = torch.logsumexp(sim_matrix, dim=1) + else: + # calculate bidirectional loss + numerator *= 2 + denominator = torch.logsumexp(sim_matrix, dim=1) + torch.logsumexp(sim_matrix.t(), dim=1) + return torch.mean(denominator - numerator) # calculated in log space + + +class NCELoss(Loss): + def __init__( + self, + prediction_key1: str, + prediction_key2: str, + is_asymmetric: bool = True, + temperature: float = 1.0, + tag: str = "NCELoss", + ): + """ + Noise Contrastive Estimation Loss + + Args: + prediction_key1 (str): key to access embedding 1. + prediction_key2 (str): key to access embedding 2. + is_asymmetric (bool, optional): specifies symmetric or asymmetric calculation of NCEloss. Defaults to True. + temperature (float, optional): temperature. Defaults to 1.0. + tag (str, optional): Defaults to "NCELoss". + """ + super().__init__(tag) + self.prediction_key1 = prediction_key1 + self.prediction_key2 = prediction_key2 + self.is_asymmetric = is_asymmetric + self.temperature = temperature + + def __call__(self, forward_batch: InferenceResultBatch) -> torch.Tensor: + """ + Args: + forward_batch (InferenceResultBatch): data batch. + + Returns: + torch.Tensor: loss tensor. 
+ """ + embedding1 = forward_batch.get_predictions(self.prediction_key1) + embedding2 = forward_batch.get_predictions(self.prediction_key2) + + contiguous_embedding1 = embedding1.contiguous() + contiguous_embedding2 = embedding2.contiguous() + + loss = nce_loss( + contiguous_embedding1, contiguous_embedding2, embedding1.device, self.is_asymmetric, self.temperature + ) + return loss diff --git a/src/modalities/models/gpt2/collator.py b/src/modalities/models/gpt2/collator.py index 5004842f..0f7ce515 100644 --- a/src/modalities/models/gpt2/collator.py +++ b/src/modalities/models/gpt2/collator.py @@ -1,4 +1,4 @@ -from dataclasses import field +from abc import ABC, abstractmethod from typing import Dict, List import torch @@ -6,9 +6,14 @@ from modalities.batch import DatasetBatch -class GPT2LLMCollator: +class CollateFnIF(ABC): + @abstractmethod + def __call__(self, batch: List[Dict[str, torch.Tensor]]) -> DatasetBatch: + raise NotImplementedError + + +class GPT2LLMCollateFn(CollateFnIF): def __init__(self, sample_key: str, target_key: str): - self.device: torch.device = field(default_factory=lambda: torch.device("cpu")) self.sample_key = sample_key self.target_key = target_key diff --git a/src/modalities/models/gpt2/gpt2_model.py b/src/modalities/models/gpt2/gpt2_model.py index eef66c44..e326b690 100644 --- a/src/modalities/models/gpt2/gpt2_model.py +++ b/src/modalities/models/gpt2/gpt2_model.py @@ -1,12 +1,12 @@ import math from enum import Enum from functools import partial -from typing import Dict +from typing import Annotated, Dict import torch import torch.nn as nn import xformers.ops as xops -from pydantic import BaseModel, confloat, conint, model_validator +from pydantic import BaseModel, Field, model_validator from torch.nn import functional as F from modalities.models.model import NNModel @@ -26,33 +26,36 @@ class ActivationType(str, Enum): class AttentionConfig(BaseModel): attention_type: AttentionType - scaling_factor: conint(ge=1) + scaling_factor: Annotated[int, Field(strict=True, ge=1)] class WeightInitailizationConfig(BaseModel): - mean: confloat(ge=0.0) - std: confloat(ge=0.0) + mean: Annotated[float, Field(strict=True, ge=0.0)] + std: Annotated[float, Field(strict=True, ge=0.0)] -class GPT2Config(BaseModel): +class GPT2LLMConfig(BaseModel): sample_key: str prediction_key: str - block_size: conint(ge=1) - vocab_size: conint(ge=1) # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency - n_layer: conint(ge=1) - n_head_q: conint(ge=1) - n_head_kv: conint(ge=1) - n_embd: conint(ge=1) - ffn_hidden: conint(ge=1) - dropout: confloat(ge=0.0) + block_size: Annotated[int, Field(strict=True, ge=1)] + vocab_size: Annotated[ + int, Field(strict=True, ge=1) + ] # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency + n_layer: Annotated[int, Field(strict=True, ge=1)] + n_head_q: Annotated[int, Field(strict=True, ge=1)] + n_head_kv: Annotated[int, Field(strict=True, ge=1)] + n_embd: Annotated[int, Field(strict=True, ge=1)] + ffn_hidden: Annotated[int, Field(strict=True, ge=1)] + + dropout: Annotated[float, Field(strict=True, ge=0.0)] bias: bool # True: bias in Linears and LayerNorms, like GPT-2. 
False: a bit better and faster attention: AttentionConfig activation: ActivationType - epsilon: confloat(ge=0.0) + epsilon: Annotated[float, Field(strict=True, ge=0.0)] weight_init: WeightInitailizationConfig @model_validator(mode="after") - def validate_sizes(self) -> "GPT2Config": + def validate_sizes(self) -> "GPT2LLMConfig": for param, param_name in zip( [self.ffn_hidden, self.vocab_size, self.n_embd], ["ffn_hidden", "vocab_size", "n_embd"] ): @@ -83,7 +86,14 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: class CausalSelfAttention(nn.Module): def __init__( - self, n_head_q: int, n_head_kv: int, n_embd: int, attention: AttentionConfig, bias: bool, dropout: float, block_size: int + self, + n_head_q: int, + n_head_kv: int, + n_embd: int, + attention: AttentionConfig, + bias: bool, + dropout: float, + block_size: int, ): super().__init__() assert n_embd % n_head_q == 0 @@ -176,7 +186,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x -class Block(nn.Module): +class GPT2Block(nn.Module): def __init__( self, n_embd: int, @@ -193,7 +203,13 @@ def __init__( super().__init__() self.ln_1 = LayerNorm(ndim=n_embd, bias=bias, epsilon=epsilon) self.attn = CausalSelfAttention( - n_head_q=n_head_q, n_head_kv=n_head_kv, n_embd=n_embd, attention=attention, bias=bias, dropout=dropout, block_size=block_size + n_head_q=n_head_q, + n_head_kv=n_head_kv, + n_embd=n_embd, + attention=attention, + bias=bias, + dropout=dropout, + block_size=block_size, ) self.ln_2 = LayerNorm(ndim=n_embd, bias=bias, epsilon=epsilon) @@ -245,7 +261,7 @@ def __init__( drop=nn.Dropout(dropout), h=nn.ModuleList( [ - Block( + GPT2Block( n_embd=n_embd, bias=bias, epsilon=epsilon, diff --git a/src/modalities/models/gpt2/preprocess_dataset.py b/src/modalities/models/gpt2/preprocess_dataset.py index 99afb069..e89d591e 100644 --- a/src/modalities/models/gpt2/preprocess_dataset.py +++ b/src/modalities/models/gpt2/preprocess_dataset.py @@ -1,21 +1,25 @@ +import os from itertools import chain -from datasets import load_dataset -from transformers import GPT2TokenizerFast, GPT2LMHeadModel, GPT2Config + from accelerate import Accelerator -import os +from datasets import load_dataset +from transformers import GPT2Config, GPT2LMHeadModel, GPT2TokenizerFast def main(): - def group_texts(examples): # Concatenate all texts. concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()} total_length = len(concatenated_examples[list(examples.keys())[0]]) - # We drop the small remainder, and if the total_length < block_size we exclude this batch and return an empty dict. - # We could add padding if the model supported it instead of this drop, you can customize this part to your needs. + # We drop the small remainder, and if the total_length < block_size + # we exclude this batch and return an empty dict. We could add padding if the + # model supported it instead of this drop, you can customize this part to your needs. total_length = (total_length // block_size) * block_size # Split by chunks of max_len. 
- result = {k: [t[i: i + block_size] for i in range(0, total_length, block_size)] for k, t in concatenated_examples.items()} + result = { + k: [t[i : i + block_size] for i in range(0, total_length, block_size)] + for k, t in concatenated_examples.items() + } result["labels"] = result["input_ids"].copy() return result diff --git a/src/modalities/models/huggingface/__init__.py b/src/modalities/models/huggingface/__init__.py new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/src/modalities/models/huggingface/__init__.py @@ -0,0 +1 @@ + diff --git a/src/modalities/models/huggingface/huggingface_models.py b/src/modalities/models/huggingface/huggingface_models.py new file mode 100644 index 00000000..4c66d46f --- /dev/null +++ b/src/modalities/models/huggingface/huggingface_models.py @@ -0,0 +1,84 @@ +from pathlib import Path +from typing import Any, Dict, List, Optional + +import torch +from pydantic import BaseModel +from transformers import AutoModelForCausalLM, AutoModelForMaskedLM, AutoTokenizer + +from modalities.config.lookup_enum import LookupEnum +from modalities.models.model import NNModel + +# Huggingface Model dependencies +# +# ModuleUtilsMixin +# GenerationMixin +# PushToHubMixin +# PeftAdapterMixin +# <- PreTrainedModel +# <- LlamaPreTrainedModel The bare LLaMA Model outputting raw hidden-states without any specific head on top. +# <- LlamaModel The bare LLaMA Model outputting raw hidden-states without any specific head on top. +# <- LlamaForCausalLM +# <- LlamaForSequenceClassification The LLaMa transformer with a sequence classif. head on top (lin. layer) + + +class HuggingFaceModelTypes(LookupEnum): + AutoModelForCausalLM = AutoModelForCausalLM + AutoModelForMaskedLM = AutoModelForMaskedLM + + +class HuggingFacePretrainedModelConfig(BaseModel): + model_type: HuggingFaceModelTypes + model_name: Path + prediction_key: str + huggingface_prediction_subscription_key: str + sample_key: str + model_args: Optional[Any] = None + kwargs: Optional[Any] = None + + +class HuggingFacePretrainedModel(NNModel): + def __init__( + self, + model_type: HuggingFaceModelTypes, + model_name: str, + prediction_key: str, + huggingface_prediction_subscription_key: str, + sample_key: str, + model_args: Optional[Any] = None, + kwargs: Optional[Any] = None, + ): + super().__init__() + if model_args is None: + model_args = [] + if kwargs is None: + kwargs = {} + self.prediction_key = prediction_key + self.huggingface_prediction_subscription_key = huggingface_prediction_subscription_key + self.sample_key = sample_key + + # NOTE: If the model needs to be downloaded, it is NOT necessary to guard the access for rank 0. 
+ # This is taken care of internally in huggingface hub see: + # https://github.com/huggingface/huggingface_hub/blob/3788f537b10c7d02149d6bf017d2ce19885f90a2/src/huggingface_hub/file_download.py#L1457 + self.huggingface_model = model_type.value.from_pretrained( + model_name, local_files_only=False, *model_args, **kwargs + ) + + def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + output = self.huggingface_model.forward(inputs[self.sample_key]) + return {self.prediction_key: output[self.huggingface_prediction_subscription_key]} + + @property + def fsdp_block_names(self) -> List[str]: + return self.huggingface_model._no_split_modules + + +if __name__ == "__main__": + tokenizer = AutoTokenizer.from_pretrained("epfl-llm/meditron-7b") + model = HuggingFacePretrainedModel( + model_type=HuggingFaceModelTypes.AutoModelForCausalLM, + model_name="epfl-llm/meditron-7b", + prediction_key="logits", + huggingface_prediction_subscription_key="logits", + sample_key="input_ids", + ) + print(model) diff --git a/src/modalities/models/model.py b/src/modalities/models/model.py index d00a8043..511419b9 100644 --- a/src/modalities/models/model.py +++ b/src/modalities/models/model.py @@ -1,9 +1,11 @@ from abc import abstractmethod from typing import Dict -from modalities.batch import DatasetBatch, InferenceResultBatch + import torch import torch.nn as nn +from modalities.batch import DatasetBatch, InferenceResultBatch + class NNModel(nn.Module): def __init__(self, seed: int = None): diff --git a/src/modalities/models/model_factory.py b/src/modalities/models/model_factory.py new file mode 100644 index 00000000..3df388d9 --- /dev/null +++ b/src/modalities/models/model_factory.py @@ -0,0 +1,44 @@ +from pathlib import Path +from typing import List + +import torch +import torch.nn as nn +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp import ShardingStrategy + +from modalities.checkpointing.checkpointing import Checkpointing +from modalities.running_env.env_utils import MixedPrecisionSettings +from modalities.running_env.fsdp.fsdp_auto_wrapper import FSDPTransformerAutoWrapPolicyFactory + + +class ModelFactory: + @staticmethod + def get_checkpointed_model(checkpointing: Checkpointing, checkpoint_path: Path, model: nn.Module) -> nn.Module: + wrapped_model = checkpointing.load_model_checkpoint( + file_path=checkpoint_path, + model=model, + ) + return wrapped_model + + @staticmethod + def get_fsdp_wrapped_model( + model: nn.Module, + sync_module_states: bool, + block_names: List[str], + mixed_precision_settings: MixedPrecisionSettings, + sharding_strategy: ShardingStrategy, + ) -> FSDP: + # Here, FSDPTransformerAutoWrapPolicyFactory is hardcoded and should be passed in instead! + # we also might want to have different auto wrap policies later... 
+ fsdp_auto_wrap_factory = FSDPTransformerAutoWrapPolicyFactory(model=model, block_names=block_names) + + # model is on CPU before input to FSDP + fsdp_model = FSDP( + model, + auto_wrap_policy=fsdp_auto_wrap_factory.get_auto_wrap_policy(), + mixed_precision=mixed_precision_settings.value, + sharding_strategy=sharding_strategy, + device_id=torch.cuda.current_device(), + sync_module_states=sync_module_states, + ) + return fsdp_model diff --git a/src/__init__.py b/src/modalities/optimizers/__init__.py similarity index 100% rename from src/__init__.py rename to src/modalities/optimizers/__init__.py diff --git a/src/modalities/optimizers/optimizer_factory.py b/src/modalities/optimizers/optimizer_factory.py new file mode 100644 index 00000000..e1282068 --- /dev/null +++ b/src/modalities/optimizers/optimizer_factory.py @@ -0,0 +1,21 @@ +import torch.nn as nn +from torch.optim import AdamW, Optimizer + +from modalities.checkpointing.checkpointing import Checkpointing + + +class OptimizerFactory: + @staticmethod + def get_adam_w(lr: float, wrapped_model: nn.Module): + model_parameters = wrapped_model.parameters() + optimizer = AdamW(params=model_parameters, lr=lr) + return optimizer + + @staticmethod + def get_checkpointed_optimizer( + checkpointing: Checkpointing, checkpoint_path, wrapped_model: nn.Module, optimizer: Optimizer + ): + wrapped_optimizer = checkpointing.load_optimizer_checkpoint( + file_path=checkpoint_path, optimizer=optimizer, wrapped_model=wrapped_model + ) + return wrapped_optimizer diff --git a/src/modalities/registry/__init__.py b/src/modalities/registry/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/modalities/registry/components.py b/src/modalities/registry/components.py new file mode 100644 index 00000000..40dfccfd --- /dev/null +++ b/src/modalities/registry/components.py @@ -0,0 +1,157 @@ +from dataclasses import dataclass +from typing import Type + +from pydantic import BaseModel +from torch.utils.data import BatchSampler, DistributedSampler +from transformers import GPT2TokenizerFast + +from modalities.checkpointing.checkpointing import Checkpointing +from modalities.checkpointing.checkpointing_execution import FSDPToDiscCheckpointing +from modalities.checkpointing.checkpointing_strategies import ( + SaveEveryKStepsCheckpointingStrategy, + SaveKMostRecentCheckpointsStrategy, +) +from modalities.config.config import ( + AdamWOptimizerConfig, + BatchSamplerConfig, + CheckpointedModelConfig, + CheckpointedOptimizerConfig, + CheckpointingConfig, + CLMCrossEntropyLossConfig, + DistributedSamplerConfig, + DummyProgressSubscriberConfig, + DummyResultSubscriberConfig, + FSDPToDiscCheckpointingConfig, + FSDPWrappedModelConfig, + GPT2LLMCollateFnConfig, + GPT2TokenizerFastConfig, + LLMDataLoaderConfig, + MemMapDatasetConfig, + OpenGPTXMMapDatasetConfig, + PackedMemMapDatasetContinuousConfig, + PackedMemMapDatasetMegatronConfig, + RichProgressSubscriberConfig, + RichResultSubscriberConfig, + SaveEveryKStepsCheckpointingStrategyConfig, + SaveKMostRecentCheckpointsStrategyConfig, + WandBEvaluationResultSubscriberConfig, +) +from modalities.dataloader.dataloader_factory import DataloaderFactory +from modalities.dataloader.dataset_factory import DatasetFactory +from modalities.logging_broker.subscriber_impl.subscriber_factory import ( + ProgressSubscriberFactory, + ResultsSubscriberFactory, +) +from modalities.loss_functions import CLMCrossEntropyLoss +from modalities.models.gpt2.collator import GPT2LLMCollateFn +from modalities.models.gpt2.gpt2_model import 
GPT2LLM, GPT2LLMConfig
+from modalities.models.huggingface.huggingface_models import (
+    HuggingFacePretrainedModel,
+    HuggingFacePretrainedModelConfig,
+)
+from modalities.models.model_factory import ModelFactory
+from modalities.optimizers.optimizer_factory import OptimizerFactory
+
+
+@dataclass
+class ComponentEntity:
+    component_key: str
+    variant_key: str
+    component_type: Type
+    component_config_type: Type[BaseModel]
+
+
+COMPONENTS = [
+    # models
+    ComponentEntity("model", "gpt2", GPT2LLM, GPT2LLMConfig),
+    ComponentEntity(
+        "model", "huggingface_pretrained_model", HuggingFacePretrainedModel, HuggingFacePretrainedModelConfig
+    ),
+    ComponentEntity("model", "checkpointed", ModelFactory.get_checkpointed_model, CheckpointedModelConfig),
+    ComponentEntity("model", "fsdp_wrapped", ModelFactory.get_fsdp_wrapped_model, FSDPWrappedModelConfig),
+    # losses
+    ComponentEntity("loss", "clm_cross_entropy_loss", CLMCrossEntropyLoss, CLMCrossEntropyLossConfig),
+    # optimizers
+    ComponentEntity("optimizer", "adam_w", OptimizerFactory.get_adam_w, AdamWOptimizerConfig),
+    ComponentEntity(
+        "optimizer", "checkpointed", OptimizerFactory.get_checkpointed_optimizer, CheckpointedOptimizerConfig
+    ),
+    # schedulers
+    # ComponentEntity("scheduler", "step_lr", torch.optim.lr_scheduler.StepLR, None),  # TODO
+    # ComponentEntity("scheduler", "constant_lr", torch.optim.lr_scheduler.ConstantLR, None),  # TODO
+    # ComponentEntity("scheduler", "onecycle_lr", torch.optim.lr_scheduler.OneCycleLR, None),  # TODO
+    # tokenizers
+    ComponentEntity("tokenizer", "gpt2_tokenizer_fast", GPT2TokenizerFast, GPT2TokenizerFastConfig),
+    # ComponentEntity("tokenizer", "llama_tokenizer_fast", GPT2TokenizerFast, None),  # TODO
+    # datasets
+    ComponentEntity("dataset", "mem_map_dataset", DatasetFactory.get_mem_map_dataset, MemMapDatasetConfig),
+    ComponentEntity(
+        "dataset",
+        "packed_mem_map_dataset_continuous",
+        DatasetFactory.get_packed_mem_map_dataset_continuous,
+        PackedMemMapDatasetContinuousConfig,
+    ),
+    ComponentEntity(
+        "dataset",
+        "packed_mem_map_dataset_megatron",
+        DatasetFactory.get_packed_mem_map_dataset_megatron,
+        PackedMemMapDatasetMegatronConfig,
+    ),
+    ComponentEntity(
+        "dataset", "open_gptx_mmap_dataset", DatasetFactory.get_open_gptx_mmap_dataset, OpenGPTXMMapDatasetConfig
+    ),
+    # samplers
+    ComponentEntity("sampler", "distributed_sampler", DistributedSampler, DistributedSamplerConfig),
+    # batch samplers
+    ComponentEntity("batch_sampler", "default", BatchSampler, BatchSamplerConfig),
+    # collators
+    ComponentEntity("collate_fn", "gpt_2_llm_collator", GPT2LLMCollateFn, GPT2LLMCollateFnConfig),
+    # data loaders
+    ComponentEntity("data_loader", "default", DataloaderFactory.get_dataloader, LLMDataLoaderConfig),
+    # ComponentEntity("data_loader", "repeating_data_loader", RepeatingDataLoader, None),  # TODO
+    # checkpointing
+    ComponentEntity("checkpointing", "default", Checkpointing, CheckpointingConfig),
+    # checkpointing strategies
+    ComponentEntity(
+        "checkpointing_strategy",
+        "save_every_k_steps_checkpointing_strategy",
+        SaveEveryKStepsCheckpointingStrategy,
+        SaveEveryKStepsCheckpointingStrategyConfig,
+    ),
+    ComponentEntity(
+        "checkpointing_strategy",
+        "save_k_most_recent_checkpoints_strategy",
+        SaveKMostRecentCheckpointsStrategy,
+        SaveKMostRecentCheckpointsStrategyConfig,
+    ),
+    # checkpointing execution
+    ComponentEntity(
+        "checkpointing_execution", "fsdp_to_disc_checkpointing", FSDPToDiscCheckpointing, FSDPToDiscCheckpointingConfig
+    ),
+    # Progress subscriber
+    ComponentEntity(
"progress_subscriber", + "dummy", + ProgressSubscriberFactory.get_dummy_progress_subscriber, + DummyProgressSubscriberConfig, + ), + ComponentEntity( + "progress_subscriber", + "rich", + ProgressSubscriberFactory.get_rich_progress_subscriber, + RichProgressSubscriberConfig, + ), + # Results subscriber + ComponentEntity( + "results_subscriber", "dummy", ResultsSubscriberFactory.get_dummy_result_subscriber, DummyResultSubscriberConfig + ), + ComponentEntity( + "results_subscriber", "rich", ResultsSubscriberFactory.get_rich_result_subscriber, RichResultSubscriberConfig + ), + ComponentEntity( + "results_subscriber", + "wandb", + ResultsSubscriberFactory.get_wandb_result_subscriber, + WandBEvaluationResultSubscriberConfig, + ), +] diff --git a/src/modalities/registry/registry.py b/src/modalities/registry/registry.py new file mode 100644 index 00000000..a55df227 --- /dev/null +++ b/src/modalities/registry/registry.py @@ -0,0 +1,44 @@ +from dataclasses import asdict +from typing import Dict, List, Optional, Tuple, Type + +from pydantic import BaseModel + +from modalities.registry.components import ComponentEntity + +Entity = Tuple[Type, Type[BaseModel]] + + +class Registry: + def __init__(self, components: Optional[List[ComponentEntity]] = None) -> None: + # maps component_key -> variant_key -> entity = (component, config) + self._registry_dict: Dict[str, Dict[str, Entity]] = {} + if components is not None: + for component in components: + self.add_entity(**asdict(component)) + + def add_entity( + self, component_key: str, variant_key: str, component_type: Type, component_config_type: Type[BaseModel] + ) -> None: + if component_key not in self._registry_dict: + self._registry_dict[component_key] = {} + self._registry_dict[component_key][variant_key] = (component_type, component_config_type) + + def get_component(self, component_key: str, variant_key: str) -> Type: + entity = self._get_entity(component_key, variant_key) + try: + return entity[0] + except IndexError as e: + raise ValueError(f"0 is not a valid index in registry[{component_key}][{variant_key}]") from e + + def get_config(self, component_key: str, variant_key: str) -> Type[BaseModel]: + entity = self._get_entity(component_key, variant_key) + try: + return entity[1] + except IndexError as e: + raise ValueError(f"1 is not a valid index in registry[{component_key}][{variant_key}]") from e + + def _get_entity(self, component_key: str, variant_key: str) -> Entity: + try: + return self._registry_dict[component_key][variant_key] + except KeyError as e: + raise ValueError(f"[{component_key}][{variant_key}] are not valid keys in registry") from e diff --git a/src/modalities/resolver_register.py b/src/modalities/resolver_register.py deleted file mode 100644 index 9f571efe..00000000 --- a/src/modalities/resolver_register.py +++ /dev/null @@ -1,147 +0,0 @@ -from typing import Any, Dict, List - -import torch.optim as optim -from class_resolver import ClassResolver -from pydantic import BaseModel -from torch.utils.data import BatchSampler, DataLoader, Sampler -from torch.utils.data.distributed import DistributedSampler -from transformers import PreTrainedTokenizer - -from modalities.checkpointing.checkpointing import CheckpointingExecutionIF, CheckpointingStrategyIF -from modalities.config.config import AppConfig, OptimizerTypes, SchedulerTypes -from modalities.config.lookup_types import ( - BatchSamplerTypes, - CheckpointingExectionTypes, - CheckpointingStrategyTypes, - CollatorTypes, - DataloaderTypes, - DatasetTypes, - LossTypes, - ModelTypes, 
- SamplerTypes, - TokenizerTypes, -) -from modalities.dataloader.dataloader import LLMDataLoader -from modalities.dataloader.dataset import Dataset -from modalities.loss_functions import CLMCrossEntropyLoss, Loss -from modalities.models.gpt2.collator import GPT2LLMCollator -from modalities.models.gpt2.gpt2_model import GPT2LLM, NNModel -from modalities.running_env.fsdp.fsdp_running_env import FSDPRunningEnv, RunningEnv, RunningEnvTypes - - -class ResolverRegister: - def __init__(self, config: AppConfig) -> None: - self._resolver_register: Dict[str, ClassResolver] = self._create_resolver_register(config=config) - - def build_component_by_config(self, config: BaseModel, extra_kwargs: Dict = {}) -> Any: - assert ( - "type_hint" in config.model_fields.keys() - ), f"Field 'type_hint' missing but needed for initalisation in {config}" - - kwargs = {key: getattr(config.config, key) for key in config.config.model_dump().keys()} - kwargs.update(extra_kwargs) # allow override via extra_kwargs, to add nested objects - return self._build_component( - register_key=config.type_hint, - register_query=config.type_hint.name, - extra_kwargs=kwargs, - ) - - def build_component_by_key_query(self, register_key: str, type_hint: str, extra_kwargs: Dict = {}) -> Any: - return self._build_component(register_key=register_key, register_query=type_hint, extra_kwargs=extra_kwargs) - - def _build_component(self, register_key: str, register_query: str, extra_kwargs: Dict = {}): - return self._resolver_register[register_key].make( - query=register_query, - pos_kwargs=extra_kwargs, - ) - - def _find_values_with_key_in_nested_structure(self, nested_structure: Dict, key: str) -> List[Any]: - found_values = [] - for k, v in nested_structure.items(): - if k == key: - found_values.append(v) - elif isinstance(v, dict): - found_values.extend(self._find_values_with_key_in_nested_structure(v, key)) - return found_values - - def _create_resolver_register(self, config: AppConfig) -> Dict[str, ClassResolver]: - set(self._find_values_with_key_in_nested_structure(nested_structure=config.model_dump(), key="type_hint")) - resolvers = { - config.running_env.type_hint: ClassResolver( - [t.value for t in RunningEnvTypes], - base=RunningEnv, - default=FSDPRunningEnv, - ), - config.model.type_hint: ClassResolver( - [t.value for t in ModelTypes], - base=NNModel, - default=GPT2LLM, - ), - config.optimizer.type_hint: ClassResolver( - [t.value for t in OptimizerTypes], - base=optim.Optimizer, - default=optim.AdamW, - ), - config.scheduler.type_hint: ClassResolver( - [t.value for t in SchedulerTypes], - base=optim.lr_scheduler.LRScheduler, - default=optim.lr_scheduler.StepLR, - ), - config.loss.type_hint: ClassResolver( - [t.value for t in LossTypes], - base=Loss, - default=CLMCrossEntropyLoss, - ), - **{ - sampler_type: ClassResolver( - classes=[t.value for t in SamplerTypes], - base=Sampler, - default=DistributedSampler, - ) - for sampler_type in SamplerTypes - }, - **{ - batch_sampler_type: ClassResolver( - classes=[t.value for t in BatchSamplerTypes], - base=BatchSampler, - default=BatchSampler, - ) - for batch_sampler_type in BatchSamplerTypes - }, - **{ - dataloader_type: ClassResolver( - [t.value for t in DataloaderTypes], - base=DataLoader, - default=LLMDataLoader, - ) - for dataloader_type in DataloaderTypes - }, - **{ - dataset_type: ClassResolver([t.value for t in DatasetTypes], base=Dataset) - for dataset_type in DatasetTypes - }, - **{ - collator_type: ClassResolver([t.value for t in CollatorTypes], base=GPT2LLMCollator) - for 
collator_type in CollatorTypes - }, - **{ - tokenizer_type: ClassResolver([t.value for t in TokenizerTypes], base=PreTrainedTokenizer) - for tokenizer_type in TokenizerTypes - }, - **{ - checkpointing_strategy_type: ClassResolver( - [t.value for t in CheckpointingStrategyTypes], base=CheckpointingStrategyIF - ) - for checkpointing_strategy_type in CheckpointingStrategyTypes - }, - **{ - checkpointing_execution_type: ClassResolver( - [t.value for t in CheckpointingExectionTypes], base=CheckpointingExecutionIF - ) - for checkpointing_execution_type in CheckpointingExectionTypes - }, - } - return resolvers - - def add_resolver(self, resolver_key: str, resolver: ClassResolver): - self._resolver_register[resolver_key] = resolver diff --git a/src/modalities/running_env/cuda_env.py b/src/modalities/running_env/cuda_env.py new file mode 100644 index 00000000..e8551869 --- /dev/null +++ b/src/modalities/running_env/cuda_env.py @@ -0,0 +1,27 @@ +import os + +import torch +import torch.distributed as dist + +from modalities.config.config import ProcessGroupBackendType + + +class CudaEnv: + def __init__( + self, + process_group_backend: ProcessGroupBackendType, + ) -> None: + self.process_group_backend = process_group_backend + # TODO we might want to set this from outside via the config + self.local_rank = int(os.getenv("LOCAL_RANK", "0")) + + def __enter__(self) -> "CudaEnv": + dist.init_process_group(self.process_group_backend.value) + torch.cuda.set_device(self.local_rank) + return self + + def __exit__(self, type, value, traceback): + pass + # TODO uncomment part below + # dist.barrier() # TODO check for concurrency issues + # dist.destroy_process_group() diff --git a/src/modalities/running_env/env_utils.py b/src/modalities/running_env/env_utils.py index 04d4d05a..87a41c62 100644 --- a/src/modalities/running_env/env_utils.py +++ b/src/modalities/running_env/env_utils.py @@ -1,11 +1,11 @@ -from enum import Enum - import torch import torch.cuda.nccl as nccl import torch.distributed as dist from pkg_resources import packaging from torch.distributed.fsdp import MixedPrecision +from modalities.config.lookup_enum import LookupEnum + def has_bfloat_support(): return ( @@ -48,7 +48,7 @@ def has_bfloat_support(): ) -class MixedPrecisionSettings(Enum): +class MixedPrecisionSettings(LookupEnum): FP_16 = fpSixteen BF_16 = bfSixteen BF_16_WORKING = bfSixteen_working diff --git a/src/modalities/running_env/fsdp/fsdp_auto_wrapper.py b/src/modalities/running_env/fsdp/fsdp_auto_wrapper.py new file mode 100644 index 00000000..634b0c12 --- /dev/null +++ b/src/modalities/running_env/fsdp/fsdp_auto_wrapper.py @@ -0,0 +1,55 @@ +import functools +import logging +from abc import ABC, abstractmethod +from typing import Callable, List + +import torch.nn as nn +from accelerate.utils.dataclasses import FullyShardedDataParallelPlugin +from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy + +from modalities.config.lookup_enum import LookupEnum + + +class FSDPAutoWrapFactoryIF(ABC): + @abstractmethod + def get_auto_wrap_policy(self) -> Callable: + raise NotImplementedError + + +class FSDPTransformerAutoWrapPolicyFactory(FSDPAutoWrapFactoryIF): + def __init__(self, model: nn.Module, block_names: List[str]) -> None: + # TODO it's problematic that we store the model in-memory here. Might get too large in RAM... 
+        self.model = model
+        self.block_names = block_names
+
+    @staticmethod
+    def _get_fsdp_blocks_from_block_names(model: nn.Module, block_names: List[str]) -> List[nn.Module]:
+        fsdp_block_types = []
+        for cls_block_name in block_names:
+            # TODO FullyShardedDataParallelPlugin from Accelerate uses string matching to find the correct
+            # block class. In the long term we should implement this ourselves in a more robust fashion.
+            block_type = FullyShardedDataParallelPlugin.get_module_class_from_name(model, cls_block_name)
+            if block_type is None:
+                raise ValueError(f"Could not find block with name {cls_block_name} in model")
+            fsdp_block_types.append(block_type)
+        return fsdp_block_types
+
+    def get_auto_wrap_policy(self) -> Callable:
+        transformer_layer_cls = self._get_fsdp_blocks_from_block_names(model=self.model, block_names=self.block_names)
+        logging.info(f"Wrapped layer classes: {transformer_layer_cls}\n")
+        print(f"\nWrapped layer classes: {transformer_layer_cls}\n")
+
+        if len(transformer_layer_cls) == 0:
+            raise ValueError("No FSDP blocks found in model")
+
+        auto_wrapper_policy = functools.partial(
+            transformer_auto_wrap_policy,
+            transformer_layer_cls={
+                *transformer_layer_cls,
+            },
+        )
+        return auto_wrapper_policy
+
+
+class FSDPAutoWrapFactoryTypes(LookupEnum):
+    FSDPTransformerAutoWrapPolicyFactory = FSDPTransformerAutoWrapPolicyFactory
diff --git a/src/modalities/running_env/fsdp/fsdp_running_env.py b/src/modalities/running_env/fsdp/fsdp_running_env.py
deleted file mode 100644
index 434444bf..00000000
--- a/src/modalities/running_env/fsdp/fsdp_running_env.py
+++ /dev/null
@@ -1,111 +0,0 @@
-import functools
-from enum import Enum
-from typing import Type
-
-import torch
-import torch.distributed as dist
-import torch.nn as nn
-from pydantic import BaseModel, ValidationError, validator
-from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-from torch.distributed.fsdp import ShardingStrategy
-from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
-
-from modalities.config.lookup_types import LookupEnum
-from modalities.config.types import ProcessGroupBackendType
-from modalities.models.gpt2.gpt2_model import Block
-from modalities.running_env.env_utils import MixedPrecisionSettings, has_bfloat_support
-from modalities.running_env.running_env import RunningEnv
-
-
-def parse_enum_by_name(name: str, enum_type: Type[Enum]) -> Enum:
-    try:
-        return enum_type[name]
-    except KeyError:
-        raise ValidationError(f"Invalid {enum_type} member name: {name}")
-
-
-transformer_auto_wrapper_policy = functools.partial(
-    transformer_auto_wrap_policy,
-    transformer_layer_cls={
-        Block,
-    },
-)
-
-
-class AutoWrapPolicies(Enum):
-    TRANSFORMER_AUTO_WRAP_POLICY = transformer_auto_wrapper_policy
-
-
-class FSDPRunningEnvConfig(BaseModel):
-    process_group_backend: ProcessGroupBackendType
-    local_rank: int
-    mixed_precision_settings: MixedPrecisionSettings
-    sharding_strategy: ShardingStrategy
-    auto_wrap_policy: AutoWrapPolicies
-
-    @validator("mixed_precision_settings", pre=True, always=True)
-    def parse_mixed_precision_setting_by_name(cls, name):
-        mixed_precision_settings: MixedPrecisionSettings = parse_enum_by_name(
-            name=name, enum_type=MixedPrecisionSettings
-        )
-        if not has_bfloat_support() and (
-            mixed_precision_settings == MixedPrecisionSettings.BF_16
-            or mixed_precision_settings == MixedPrecisionSettings.BF_16_WORKING
-        ):
-            raise ValueError("BF16 not supported in the current environment")
-        return mixed_precision_settings
-
-    @validator("sharding_strategy",
pre=True, always=True) - def parse_sharding_strategy_by_name(cls, name): - return parse_enum_by_name(name=name, enum_type=ShardingStrategy) - - @validator("auto_wrap_policy", pre=True, always=True) - def parse_auto_wrap_policy_by_name(cls, name): - return parse_enum_by_name(name=name, enum_type=AutoWrapPolicies) - - -class FSDPRunningEnv(RunningEnv): - def __init__( - self, - process_group_backend: ProcessGroupBackendType, - local_rank: int, - mixed_precision_settings: MixedPrecisionSettings, - sharding_strategy: ShardingStrategy, - auto_wrap_policy: AutoWrapPolicies, - ) -> None: - self.process_group_backend = process_group_backend - self.local_rank = local_rank - self.mixed_precision_settings = mixed_precision_settings - self.sharding_strategy = sharding_strategy - self.auto_wrap_policy = auto_wrap_policy - - def __enter__(self) -> "RunningEnv": - dist.init_process_group(self.process_group_backend.value) - torch.cuda.set_device(self.local_rank) - return self - - def __exit__(self, type, value, traceback): - pass # TODO uncomment part below - # dist.barrier() # TODO check for concurrency issues - # dist.destroy_process_group() - - def wrap_model(self, model: nn.Module, sync_module_states: bool) -> FSDP: - # model is on CPU before input to FSDP - fsdp_model = FSDP( - model, - auto_wrap_policy=self.auto_wrap_policy.value, - mixed_precision=self.mixed_precision_settings.value, - sharding_strategy=self.sharding_strategy, - device_id=torch.cuda.current_device(), - sync_module_states=sync_module_states, - ) - return fsdp_model - - -class RunningEnvTypes(LookupEnum): - FSDPRunningEnv = FSDPRunningEnv - - -class RunningEnvConfig(BaseModel): - type_hint: RunningEnvTypes - config: FSDPRunningEnvConfig diff --git a/src/modalities/running_env/running_env.py b/src/modalities/running_env/running_env.py deleted file mode 100644 index 8d121353..00000000 --- a/src/modalities/running_env/running_env.py +++ /dev/null @@ -1,15 +0,0 @@ -from abc import ABC, abstractmethod - -import torch.nn as nn - - -class RunningEnv(ABC, object): - def __enter__(self) -> "RunningEnv": - raise NotImplementedError - - def __exit__(self, type, value, traceback): - raise NotImplementedError - - @abstractmethod - def wrap_model(self, model: nn.Module, sync_module_states: bool) -> nn.Module: - raise NotImplementedError diff --git a/src/modalities/test.py b/src/modalities/test.py index ea16a091..f81c3630 100644 --- a/src/modalities/test.py +++ b/src/modalities/test.py @@ -3,7 +3,6 @@ from rich.progress import Progress with Progress() as progress: - task1 = progress.add_task("[red]Downloading...", total=1000) task2 = progress.add_task("[green]Processing...", total=1000) task3 = progress.add_task("[cyan]Cooking...", total=1000) @@ -12,4 +11,4 @@ progress.update(task1, advance=0.5) progress.update(task2, advance=0.3) progress.update(task3, advance=0.9) - time.sleep(0.02) \ No newline at end of file + time.sleep(0.02) diff --git a/src/modalities/trainer.py b/src/modalities/trainer.py index 12e8193c..6994af98 100644 --- a/src/modalities/trainer.py +++ b/src/modalities/trainer.py @@ -26,12 +26,12 @@ def __init__( local_rank: int, batch_progress_publisher: MessagePublisher[BatchProgressUpdate], evaluation_result_publisher: MessagePublisher[EvaluationResultBatch], - gradient_acc_step: int, + gradient_acc_steps: int, ) -> None: self.local_rank = local_rank self.batch_progress_publisher = batch_progress_publisher self.evaluation_result_publisher = evaluation_result_publisher - self.gradient_acc_step = gradient_acc_step + 
self.gradient_acc_steps = gradient_acc_steps def _train_batch( self, @@ -43,10 +43,10 @@ def _train_batch( data_loader: LLMDataLoader, ) -> torch.Tensor: result_batch = model_predict_batch(model=model, batch=batch) - loss = loss_fun(result_batch) / self.gradient_acc_step + loss = loss_fun(result_batch) / self.gradient_acc_steps loss.backward() - if (batch_id + 1) % self.gradient_acc_step == 0 or (batch_id + 1) == len(data_loader): + if (batch_id + 1) % self.gradient_acc_steps == 0 or (batch_id + 1) == len(data_loader): optimizer.step() optimizer.zero_grad() return loss @@ -63,9 +63,11 @@ def train( local_sample_id_to_global_sample_id: Callable[[int], int], ): model.train() - cummulated_loss = self._reset_loss() + cumulated_loss = self._reset_loss() thoughput_aggregator = Aggregator[ThroughputAggregationKeys]() + device = torch.device(self.local_rank if torch.cuda.is_available() else "cpu") + # batch loop batch: DatasetBatch # TODO: why do we need a barrier here? @@ -86,14 +88,14 @@ def train( ) forward_backward_time_recorder.stop() # Save the batch loss - cummulated_loss[0] += batch_loss.item() - cummulated_loss[1] += len(batch) - batch_length_tensor = torch.tensor(len(batch)).to(torch.device(self.local_rank)) + cumulated_loss[0] += batch_loss.item() + cumulated_loss[1] += len(batch) + batch_length_tensor = torch.tensor(len(batch)).to(device) thoughput_aggregator.add_value(key=ThroughputAggregationKeys.NUM_SAMPLES, value=batch_length_tensor) self._publish_progress( batch_progress_publisher=self.batch_progress_publisher, local_batch_id=local_train_batch_id, - batch_size=train_loader.sampler_batch_size, + batch_size=train_loader.batch_size, dataloader_tag=train_loader.dataloader_tag, local_sample_id_to_global_sample_id=local_sample_id_to_global_sample_id, ) @@ -101,29 +103,27 @@ def train( # Check, if model should be evaluated if (local_train_batch_id + 1) % callback_interval_in_batches == 0: if local_train_batch_id > 0: - foward_backward_time = torch.tensor(forward_backward_time_recorder.delta_t).to( - torch.device(self.local_rank) - ) + forward_backward_time = torch.tensor(forward_backward_time_recorder.delta_t).to(device) forward_backward_time_recorder.reset() thoughput_aggregator.add_value( - key=ThroughputAggregationKeys.FORWARD_BACKWARD_TIME, value=foward_backward_time + key=ThroughputAggregationKeys.FORWARD_BACKWARD_TIME, value=forward_backward_time ) synced_num_samples = thoughput_aggregator.get_all_reduced_value( ThroughputAggregationKeys.NUM_SAMPLES ) - synced_foward_backward_time = thoughput_aggregator.get_all_reduced_value( + synced_forward_backward_time = thoughput_aggregator.get_all_reduced_value( ThroughputAggregationKeys.FORWARD_BACKWARD_TIME, reduce_operation=dist.ReduceOp.MAX ) - synced_num_samples_per_second = synced_num_samples / synced_foward_backward_time + synced_num_samples_per_second = synced_num_samples / synced_forward_backward_time # TODO: insert reducer from outside so Trainer is independent of FSDP train_loss = Reducer.reduce( - tensor=cummulated_loss, + tensor=cumulated_loss, operation=dist.ReduceOp.SUM, post_processing_fun=lambda t: t[0] / t[1], ) local_train_sample_id = Trainer._get_local_sample_id( - batch_id=local_train_batch_id, batch_size=train_loader.sampler_batch_size + batch_id=local_train_batch_id, batch_size=train_loader.batch_size ) global_train_sample_id = local_sample_id_to_global_sample_id(local_train_sample_id) @@ -144,19 +144,19 @@ def train( model.train() # TODO early stopping - cummulated_loss = self._reset_loss() + cumulated_loss = 
self._reset_loss()
             # we start the time recoder here again to also capture the time spend loading
             # via the dataloader.
             forward_backward_time_recorder.start()

     def _reset_loss(self):
         # TODO: we should handle the device assignment more centrally.
-        cummulated_loss = torch.zeros(2)
+        cumulated_loss = torch.zeros(2)
         if torch.cuda.is_available():
-            cummulated_loss = cummulated_loss.to(torch.device(self.local_rank))
+            cumulated_loss = cumulated_loss.to(torch.device(self.local_rank))
         else:
-            cummulated_loss = cummulated_loss.to("cpu")
-        return cummulated_loss
+            cumulated_loss = cumulated_loss.to("cpu")
+        return cumulated_loss

     @staticmethod
     def _publish_progress(
diff --git a/src/modalities/util.py b/src/modalities/util.py
index 027e14a1..7eafb921 100644
--- a/src/modalities/util.py
+++ b/src/modalities/util.py
@@ -1,16 +1,44 @@
 import time
+import warnings
 from datetime import datetime
 from enum import Enum
 from types import TracebackType
-from typing import Callable, Dict, Generic, TypeVar
+from typing import Callable, Dict, Generic, Type, TypeVar

 import torch
 import torch.distributed as dist
+from pydantic import ValidationError

 from modalities.exceptions import TimeRecorderStateError
 from modalities.running_env.fsdp.reducer import Reducer


+def get_callback_interval_in_batches_per_rank(
+    callback_interval_in_samples: int, local_train_micro_batch_size: int, world_size: int, gradient_acc_steps: int
+):
+    num_local_train_micro_batches_exact = callback_interval_in_samples / local_train_micro_batch_size / world_size
+    num_local_train_micro_batches_ret = max(
+        callback_interval_in_samples // local_train_micro_batch_size // world_size, 1
+    )
+    if num_local_train_micro_batches_exact != num_local_train_micro_batches_ret:
+        warnings.warn(
+            f"Calculated callback_interval_in_batches_per_rank is not an integer. "
+            f"Clipping {num_local_train_micro_batches_exact} to {num_local_train_micro_batches_ret} "
+        )
+    assert (
+        num_local_train_micro_batches_ret % gradient_acc_steps == 0
+    ), "callback_interval_in_batches_per_rank must be divisible by gradient_acc_steps"
+    return num_local_train_micro_batches_ret
+
+
+def parse_enum_by_name(name: str, enum_type: Type[Enum]) -> Enum:
+    try:
+        val = enum_type[name]
+        return val
+    except KeyError:
+        raise ValidationError(f"Invalid {enum_type} member name: {name}")
+
+
 def get_date_of_run():
     """create date and time for file save uniqueness
     example: 2022-05-07__14-31-22'
diff --git a/src/modalities/utils/generate_text.py b/src/modalities/utils/generate_text.py
index 21d86c2c..d503b4fd 100755
--- a/src/modalities/utils/generate_text.py
+++ b/src/modalities/utils/generate_text.py
@@ -10,15 +10,16 @@
 from pathlib import Path

 import torch
-from omegaconf import OmegaConf
 from torch.nn import functional as F
 from transformers import PreTrainedTokenizer

-from modalities.config.config import AppConfig
-from modalities.resolver_register import ResolverRegister
+from modalities.config.component_factory import ComponentFactory
+from modalities.config.config import ComponentsInferenceModel, load_app_config_dict
+from modalities.registry.components import COMPONENTS
+from modalities.registry.registry import Registry

 chat_prefix = """
-This is a converstation between a user and a helpful bot, which answers the user's questsions as good as possible.
+This is a conversation between a user and a helpful bot, which answers the user's questions as well as possible.
 user: What is 1+1?
 bot: 1+1 is 2.
@@ -95,11 +96,15 @@ def main(model_path: Path, config_path: Path, tokenizer: PreTrainedTokenizer, ma state_dict = torch.load(path) print(f"using {model_path}") - config_dict = OmegaConf.load(config_path) - config_dict = OmegaConf.to_container(config_dict, resolve=True) - config = AppConfig.model_validate(config_dict) - resolvers = ResolverRegister(config=config) - model: torch.nn.Module = resolvers.build_component_by_config(config=config.model) + config_dict = load_app_config_dict(config_path) + registry = Registry(COMPONENTS) + component_factory = ComponentFactory(registry=registry) + components = component_factory.build_components( + config_dict=config_dict, components_model_type=ComponentsInferenceModel + ) + + model = components.wrapped_model + model.load_state_dict(state_dict) model.eval() @@ -109,11 +114,11 @@ def main(model_path: Path, config_path: Path, tokenizer: PreTrainedTokenizer, ma if chat is True: prompt = input("enter question> ").strip() prompt = chat_prefix + chat_prompt_template.format(prompt=prompt) - generate(model, tokenizer, prompt, config.model.config.block_size, max_new_tokens) + generate(model, tokenizer, prompt, model.config.block_size, max_new_tokens) else: prompt = input("enter prompt> ") print(prompt, end="") - generate(model, tokenizer, prompt, config.model.config.block_size, max_new_tokens) + generate(model, tokenizer, prompt, model.config.block_size, max_new_tokens) except KeyboardInterrupt: print("closing app...") break diff --git a/tests/checkpointing/gpt2_config.yaml b/tests/checkpointing/gpt2_config.yaml index ea401600..3865149e 100644 --- a/tests/checkpointing/gpt2_config.yaml +++ b/tests/checkpointing/gpt2_config.yaml @@ -1,24 +1,23 @@ -llm_model_conf: - sample_key: input_ids - prediction_key: "logits" - block_size: 1024 - vocab_size: 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency - n_layer: 12 - n_head: 12 - ffn_hidden: 2048 - n_embd: 768 - dropout: 0.0 - bias: true # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster - attention: - attention_type: pytorch_flash_attention - scaling_factor: 3 - activation: fused_swiglu - epsilon: 1e-5 - weight_init: - mean: 0.0 - std: 0.02 - -running_env_conf: - process_group_backend: "nccl" - local_rank: ${oc.env:LOCAL_RANK} +model: + component_key: model + variant_key: gpt2 + config: + sample_key: "input_ids" # TODO reference this + prediction_key: "logits" # TODO reference this + block_size: 256 # TODO reference this (same as sequence length) + vocab_size: 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency + n_layer: 2 + n_head: 4 + ffn_hidden: 128 + n_embd: 128 + dropout: 0.0 + bias: true # True: bias in Linears and LayerNorms, like GPT-2. 
False: a bit better and faster + attention: + attention_type: default_attention # pytorch_flash_attention + scaling_factor: 3 + activation: gelu + epsilon: 1e-5 + weight_init: + mean: 0.0 + std: 0.02 diff --git a/tests/checkpointing/test_checkpoint_execution_functions.py b/tests/checkpointing/test_checkpoint_execution_functions.py index 83e0dbe0..229e703a 100644 --- a/tests/checkpointing/test_checkpoint_execution_functions.py +++ b/tests/checkpointing/test_checkpoint_execution_functions.py @@ -3,8 +3,10 @@ import pytest import torch.nn as nn from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp import ShardingStrategy -from src.modalities.checkpointing.checkpointing_execution import FSDPToDiscCheckpointing +from modalities.checkpointing.checkpointing_execution import FSDPToDiscCheckpointing +from modalities.running_env.env_utils import MixedPrecisionSettings @pytest.mark.skip @@ -28,7 +30,12 @@ def test_get_paths_to_delete(tmp_path): # pytest temp path p.write_text(CONTENT) checkpointing = FSDPToDiscCheckpointing( - checkpoint_path=d, experiment_id=str(1), global_rank=0, model_wrapping_fn=dummy_method + checkpoint_path=d, + experiment_id=str(1), + global_rank=0, + block_names=["model"], + mixed_precision_settings=MixedPrecisionSettings.BF_16, + sharding_strategy=ShardingStrategy.FULL_SHARD, ) files_paths_to_delete = checkpointing._get_paths_to_delete(global_train_sample_id=100) assert len(files_paths_to_delete) != 0 @@ -50,7 +57,9 @@ def test_delete_checkpoint(tmpdir): checkpoint_path=directory, experiment_id=experiment_id, global_rank=0, - model_wrapping_fn=dummy_method, + block_names=["model"], + mixed_precision_settings=MixedPrecisionSettings.BF_16, + sharding_strategy=ShardingStrategy.FULL_SHARD, ) checkpointing._delete_checkpoint(global_train_sample_id=100) assert is_empty_directory((directory / experiment_id).__str__()) diff --git a/tests/checkpointing/test_fsdp_to_disc_checkpointing.py b/tests/checkpointing/test_fsdp_to_disc_checkpointing.py index 13bface2..a04bd44a 100644 --- a/tests/checkpointing/test_fsdp_to_disc_checkpointing.py +++ b/tests/checkpointing/test_fsdp_to_disc_checkpointing.py @@ -1,69 +1,73 @@ +import os import tempfile from copy import deepcopy from pathlib import Path -from typing import Any, Dict, Generator +from typing import Dict import pytest import torch import torch.distributed as dist -from pydantic import BaseModel +import torch.nn as nn from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp import ShardingStrategy from torch.nn import CrossEntropyLoss from torch.optim import AdamW, Optimizer from modalities.__main__ import load_app_config_dict -from modalities.checkpointing.checkpointing_execution import FSDPToDiscCheckpointing -from modalities.models.gpt2.gpt2_model import GPT2LLM, GPT2Config -from modalities.running_env.fsdp.fsdp_running_env import FSDPRunningEnv, FSDPRunningEnvConfig, RunningEnv +from modalities.checkpointing.checkpointing_execution import CheckpointingEntityType, FSDPToDiscCheckpointing +from modalities.config.component_factory import ComponentFactory +from modalities.config.config import ProcessGroupBackendType +from modalities.models.gpt2.gpt2_model import GPT2LLM, GPT2LLMConfig +from modalities.models.model_factory import ModelFactory +from modalities.optimizers.optimizer_factory import OptimizerFactory +from modalities.running_env.cuda_env import CudaEnv +from modalities.running_env.env_utils import MixedPrecisionSettings # NOTE: We need to run 
the tests in a torch distributed environment with at least two GPUs. # CUDA_VISIBLE_DEVICES=0,1 torchrun --rdzv-endpoint localhost:29502 --nnodes 1 --nproc_per_node 2 \ -# /path/to/pytest path/to/test_fsdp_to_disc_checkpointing.py +# $(which pytest) path/to/test_fsdp_to_disc_checkpointing.py _ROOT_DIR = Path(__file__).parents[1] -class ExperimentConfig(BaseModel): - llm_model_conf: GPT2Config # Named it llm_model_conf as model_ is a protected namespace in pydantic - running_env_conf: FSDPRunningEnvConfig - - -@pytest.mark.skip( - reason="Need to fix absolute path for config_file_path and needs to be run via " - "torchrun in a torch distributed environment (torchrun)" +@pytest.mark.skipif( + "RANK" not in os.environ or torch.cuda.device_count() < 2, + reason="This e2e test requires 2 GPUs and a torchrun distributed environment.", ) class TestFSDPToDiscCheckpointing: - @pytest.fixture - def experiment_config(self) -> ExperimentConfig: - config_file_path = _ROOT_DIR / Path("tests/checkpointing/gpt2_config.yaml") + @pytest.fixture(scope="function") + def gpt2_model_config(self) -> GPT2LLMConfig: + config_file_path = Path("tests/checkpointing/gpt2_config.yaml") config_dict = load_app_config_dict(config_file_path=config_file_path) - experiment_config = ExperimentConfig.model_validate(config_dict) - return experiment_config + config = GPT2LLMConfig(**config_dict["model"]["config"]) + return config @pytest.fixture(scope="function") - def gpt2_model(self, experiment_config: ExperimentConfig) -> GPT2LLM: - model = GPT2LLM(config=experiment_config.llm_model_conf) + def gpt2_model(self, gpt2_model_config: GPT2LLMConfig) -> GPT2LLM: + config_dict = ComponentFactory.base_model_to_dict(gpt2_model_config) + model = GPT2LLM(**config_dict) return model @pytest.fixture(scope="function") - def gpt2_model_2(self, experiment_config: ExperimentConfig) -> GPT2LLM: - model = GPT2LLM(config=experiment_config.llm_model_conf) + def gpt2_model_2(self, gpt2_model_config: GPT2LLMConfig) -> GPT2LLM: + config_dict = ComponentFactory.base_model_to_dict(gpt2_model_config) + model = GPT2LLM(**config_dict) return model @pytest.fixture - def fsdp_running_env(self, experiment_config: ExperimentConfig) -> Generator[RunningEnv, Any, Any]: - running_env = FSDPRunningEnv(**dict(experiment_config.running_env_conf)) - with running_env as running_env: - yield running_env - - @pytest.fixture - def fsdp_wrapped_model(self, gpt2_model: GPT2LLM, fsdp_running_env) -> FSDP: - wrapped_model: FSDP = FSDPRunningEnv.wrap_model(gpt2_model, sync_module_states=True) + def fsdp_wrapped_model(self, gpt2_model: GPT2LLM) -> FSDP: + wrapped_model: FSDP = ModelFactory.get_fsdp_wrapped_model( + gpt2_model, + sync_module_states=True, + block_names=["GPT2Block"], + mixed_precision_settings=MixedPrecisionSettings.FP_16, + sharding_strategy=ShardingStrategy.FULL_SHARD, + ) return wrapped_model @pytest.fixture - def optimizer(self, fsdp_wrapped_model: GPT2LLM) -> Optimizer: - optimizer = AdamW(fsdp_wrapped_model.parameters(), lr=0.001) + def optimizer(self, fsdp_wrapped_model: nn.Module) -> Optimizer: + optimizer = OptimizerFactory.get_adam_w(wrapped_model=fsdp_wrapped_model, lr=0.001) return optimizer @pytest.fixture @@ -71,20 +75,23 @@ def temporary_checkpoint_folder_path(self): with tempfile.TemporaryDirectory() as tmp_dir_path: yield Path(tmp_dir_path) + @pytest.fixture(autouse=True) + def cuda_env_context(self): + with CudaEnv(process_group_backend=ProcessGroupBackendType.nccl): + yield + @staticmethod - def _generate_batch(experiment_config: 
ExperimentConfig): + def _generate_batch(gpt2_model_config: GPT2LLMConfig): # prepare input and targets - data = torch.randint( - 0, experiment_config.llm_model_conf.vocab_size, (8, experiment_config.llm_model_conf.block_size + 1) - ).cuda() - batch_input_ids_dict = {experiment_config.llm_model_conf.sample_key: data[:, :-1]} + data = torch.randint(0, gpt2_model_config.vocab_size, (8, gpt2_model_config.block_size + 1)).cuda() + batch_input_ids_dict = {gpt2_model_config.sample_key: data[:, :-1]} batch_target_ids = data[:, 1:] batch_target_ids = batch_target_ids.contiguous() return batch_input_ids_dict, batch_target_ids @staticmethod def _forward_backward_pass( - experiment_config: ExperimentConfig, + gpt2_model_config: GPT2LLMConfig, model: FSDP, optimizer: Optimizer, batch_input_ids_dict: Dict, @@ -96,7 +103,7 @@ def _forward_backward_pass( optimizer.zero_grad() # forward pass - predictions = model.forward(inputs=batch_input_ids_dict)[experiment_config.llm_model_conf.prediction_key] + predictions = model.forward(inputs=batch_input_ids_dict)[gpt2_model_config.prediction_key] predictions = predictions.contiguous() # backward pass loss = ce_loss(predictions.view(-1, predictions.size(-1)), batch_target_ids.view(-1)) @@ -147,25 +154,27 @@ def test_save_checkpoint_after_backward_pass( optimizer: Optimizer, temporary_checkpoint_folder_path: Path, gpt2_model_2: GPT2LLM, - experiment_config: ExperimentConfig, + gpt2_model_config: GPT2LLMConfig, ): experiment_id = "0" - global_train_batch_id = 1 + global_train_sample_id = 1 checkpointing = FSDPToDiscCheckpointing( checkpoint_path=temporary_checkpoint_folder_path, experiment_id=experiment_id, global_rank=dist.get_rank(), - model_wrapping_fn=FSDPRunningEnv.wrap_model, + block_names=["GPT2Block"], + mixed_precision_settings=MixedPrecisionSettings.FP_16, + sharding_strategy=ShardingStrategy.FULL_SHARD, ) untrained_model_parameters = [p.clone() for p in fsdp_wrapped_model.parameters()] untrained_optimizer_state_dict = deepcopy(optimizer.state_dict()) # run backward pass - batch_input_ids_dict, batch_target_ids = self._generate_batch(experiment_config) + batch_input_ids_dict, batch_target_ids = self._generate_batch(gpt2_model_config) self._forward_backward_pass( - experiment_config=experiment_config, + gpt2_model_config=gpt2_model_config, model=fsdp_wrapped_model, optimizer=optimizer, batch_input_ids_dict=batch_input_ids_dict, @@ -176,23 +185,28 @@ def test_save_checkpoint_after_backward_pass( # save model and optimizer before backward pass checkpointing._save_checkpoint( - model=fsdp_wrapped_model, optimizer=optimizer, global_train_batch_id=global_train_batch_id + model=fsdp_wrapped_model, optimizer=optimizer, global_train_sample_id=global_train_sample_id ) # load the model checkpoint - fsdp_wrapped_model_2 = checkpointing.load_model_checkpoint( - model=gpt2_model_2, + model_checkpointing_path = checkpointing._get_checkpointing_path( experiment_id=experiment_id, - global_train_batch_id=global_train_batch_id, + global_train_sample_id=global_train_sample_id, + entity_type=CheckpointingEntityType.MODEL, + ) + fsdp_wrapped_model_2 = checkpointing.load_model_checkpoint( + model=gpt2_model_2, file_path=model_checkpointing_path ) optimizer_2 = AdamW(fsdp_wrapped_model_2.parameters(), lr=0.001) - checkpointing.load_optimizer_checkpoint( - optimizer=optimizer_2, - model=fsdp_wrapped_model_2, + optimizer_checkpointing_path = checkpointing._get_checkpointing_path( experiment_id=experiment_id, - global_train_batch_id=global_train_batch_id, + 
global_train_sample_id=global_train_sample_id, + entity_type=CheckpointingEntityType.OPTIMIZER, + ) + checkpointing.load_optimizer_checkpoint( + optimizer=optimizer_2, wrapped_model=fsdp_wrapped_model_2, file_path=optimizer_checkpointing_path ) loaded_and_updated_model_parameters = [p.clone() for p in fsdp_wrapped_model_2.parameters()] @@ -218,17 +232,17 @@ def test_save_checkpoint_after_backward_pass( # we do another forward/backward pass and check # if the weights are equally updated for the loaded model as for the not-loaded model # run backward pass - batch_input_ids_dict, batch_target_ids = self._generate_batch(experiment_config) + batch_input_ids_dict, batch_target_ids = self._generate_batch(gpt2_model_config) loss_1 = self._forward_backward_pass( - experiment_config=experiment_config, + gpt2_model_config=gpt2_model_config, model=fsdp_wrapped_model, optimizer=optimizer, batch_input_ids_dict=batch_input_ids_dict, batch_target_ids=batch_target_ids, ) loss_2 = self._forward_backward_pass( - experiment_config=experiment_config, + gpt2_model_config=gpt2_model_config, model=fsdp_wrapped_model_2, optimizer=optimizer_2, batch_input_ids_dict=batch_input_ids_dict, diff --git a/tests/config/__init__.py b/tests/config/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/config/components.py b/tests/config/components.py new file mode 100644 index 00000000..67c9e9a3 --- /dev/null +++ b/tests/config/components.py @@ -0,0 +1,48 @@ +from enum import Enum +from typing import List + + +class Component_V_W_X_IF: + def print(self) -> None: + print("ComponentIF") + + +# Dependencies + + +class ComponentV(Component_V_W_X_IF): + def __init__(self, val_v: str) -> None: + self.val_v = val_v + + +class ComponentW(Component_V_W_X_IF): + def __init__(self, val_w: str) -> None: + self.val_w = val_w + + +# Components + + +class ComponentX(Component_V_W_X_IF): + def __init__(self, val_x: str, single_dependency: Component_V_W_X_IF) -> None: + self.val_x = val_x + self.single_dependency = single_dependency + + +class ComponentY: + def __init__(self, val_y: str, multi_dependency: List[Component_V_W_X_IF]) -> None: + self.val_y = val_y + self.multi_dependency = multi_dependency + + +class ComponentZ: + def __init__(self, val_z: str) -> None: + self.val_z = val_z + + +class ComponentTypes(Enum): + COMP_V = ComponentV + COMP_W = ComponentW + COMP_X = ComponentX + COMP_Y = ComponentY + COMP_Z = ComponentZ diff --git a/tests/config/configs.py b/tests/config/configs.py new file mode 100644 index 00000000..9b59b748 --- /dev/null +++ b/tests/config/configs.py @@ -0,0 +1,30 @@ +from typing import Annotated, List + +from pydantic import BaseModel + +from modalities.config.config import PydanticThirdPartyTypeIF +from tests.config.components import Component_V_W_X_IF + +PydanticComponent_V_W_X_IF_Type = Annotated[Component_V_W_X_IF, PydanticThirdPartyTypeIF(Component_V_W_X_IF)] + + +class CompVConfig(BaseModel): + val_v: str + + +class CompWConfig(BaseModel): + val_w: str + + +class CompXConfig(BaseModel): + val_x: str + single_dependency: PydanticComponent_V_W_X_IF_Type + + +class CompYConfig(BaseModel): + val_y: str + multi_dependency: List[PydanticComponent_V_W_X_IF_Type] + + +class CompZConfig(BaseModel): + val_z: str diff --git a/tests/config/custom_components.py b/tests/config/custom_components.py new file mode 100644 index 00000000..dd849e0d --- /dev/null +++ b/tests/config/custom_components.py @@ -0,0 +1,35 @@ +from abc import ABC +from enum import Enum +from typing import Literal + +from pydantic 
import BaseModel, validator


+class CustomComponent1:
+    def __init__(self, val_1: str) -> None:
+        self.val_1 = val_1
+
+
+class CustomComponentTypes(Enum):
+    CUSTOM_COMP_1 = CustomComponent1
+
+
+class CustomCompConfigABC(BaseModel, ABC):
+    # TODO make this a string and then implement the mapping
+    # to the class outside of the basemodel (i.e. in the factory)
+    type_hint: Enum
+
+    @validator("type_hint", pre=True, allow_reuse=True, check_fields=False)
+    def _string_to_enum(cls, key: str):
+        if isinstance(key, str):
+            try:
+                key = CustomComponentTypes[key]
+            except KeyError as e:
+                raise ValueError(f"{key} is not a valid ComponentType") from e
+            return key
+        return key
+
+
+class CustomComp1Config(CustomCompConfigABC):
+    type_hint: Literal[CustomComponentTypes.CUSTOM_COMP_1]
+    val_1: str
diff --git a/tests/config/test_component_factory.py b/tests/config/test_component_factory.py
new file mode 100644
index 00000000..94644da5
--- /dev/null
+++ b/tests/config/test_component_factory.py
@@ -0,0 +1,111 @@
+from pathlib import Path
+
+import pytest
+
+from modalities.config.component_factory import ComponentFactory
+from modalities.config.config import load_app_config_dict
+from modalities.registry.components import ComponentEntity
+from modalities.registry.registry import Registry
+from tests.config.components import ComponentV, ComponentW, ComponentX, ComponentY
+from tests.config.configs import CompVConfig, CompWConfig, CompXConfig, CompYConfig
+
+
+@pytest.fixture(scope="function")
+def component_factory() -> ComponentFactory:
+    components = [
+        ComponentEntity("COMP_V", "default", ComponentV, CompVConfig),
+        ComponentEntity("COMP_W", "default", ComponentW, CompWConfig),
+        ComponentEntity("COMP_X", "default", ComponentX, CompXConfig),
+        ComponentEntity("COMP_Y", "default", ComponentY, CompYConfig),
+    ]
+
+    registry = Registry(components=components)
+    component_factory = ComponentFactory(registry=registry)
+    return component_factory
+
+
+@pytest.mark.parametrize(
+    "config_file_path",
+    [
+        Path("tests/config/test_configs/config_backward_reference.yaml"),
+        Path("tests/config/test_configs/config_forward_reference.yaml"),
+    ],
+)
+def test_backward_reference(config_file_path: Path, component_factory: ComponentFactory):
+    component_names = ["comp_x_1", "comp_y_1"]
+
+    config_dict = load_app_config_dict(config_file_path=config_file_path)
+
+    components = component_factory._build_config(config_dict=config_dict, component_names=component_names)
+
+    # make sure that the reference is not identical, despite both being of type COMP_W
+    assert components["comp_x_1"].single_dependency != components["comp_y_1"].multi_dependency[0]
+    # make sure that the reference is identical, since we are referencing comp_x_1 in the multi dependency of comp_y_1
+    assert components["comp_x_1"] == components["comp_y_1"].multi_dependency[2]
+
+
+@pytest.mark.parametrize(
+    "config_file_path",
+    [
+        Path("tests/config/test_configs/config_non_existing_reference.yaml"),
+    ],
+)
+def test_non_existing_reference(config_file_path: Path, component_factory: ComponentFactory):
+    component_names = ["comp_x_1", "comp_y_1"]
+
+    config_dict = load_app_config_dict(config_file_path=config_file_path)
+
+    with pytest.raises(KeyError):
+        component_factory._build_config(config_dict=config_dict, component_names=component_names)
+
+
+@pytest.mark.parametrize(
+    "config_file_path",
+    [
+        Path("tests/config/test_configs/config_hierarchical_list_component.yaml"),
+    ],
+)
+def test_hierarchical_component_instantiation(config_file_path: Path,
component_factory: ComponentFactory): + component_names = ["comp_y_1"] + + config_dict = load_app_config_dict(config_file_path=config_file_path) + + components = component_factory._build_config(config_dict=config_dict, component_names=component_names) + + assert isinstance(components["comp_y_1"].multi_dependency[0], ComponentW) + assert isinstance(components["comp_y_1"].multi_dependency[1], ComponentV) + assert isinstance(components["comp_y_1"], ComponentY) + + +@pytest.mark.parametrize( + "config_file_path", + [ + Path("tests/config/test_configs/config_hierarchical_list_component.yaml"), + ], +) +def test_component_filter(config_file_path: Path, component_factory: ComponentFactory): + component_names = ["comp_y_1"] + + config_dict = load_app_config_dict(config_file_path=config_file_path) + + components = component_factory._build_config(config_dict=config_dict, component_names=component_names) + assert "comp_y_1" in components + + component_names += "abc" + with pytest.raises(KeyError): + components = component_factory._build_config(config_dict=config_dict, component_names=component_names) + + +@pytest.mark.parametrize( + "config_file_path", + [ + Path("tests/config/test_configs/config_single_component.yaml"), + ], +) +def test_single_component(config_file_path: Path, component_factory: ComponentFactory): + component_names = ["custom_comp_1"] + + config_dict = load_app_config_dict(config_file_path=config_file_path) + + components = component_factory._build_config(config_dict=config_dict, component_names=component_names) + assert "custom_comp_1" in components diff --git a/tests/config/test_configs/config_backward_reference.yaml b/tests/config/test_configs/config_backward_reference.yaml new file mode 100644 index 00000000..bf82cb7d --- /dev/null +++ b/tests/config/test_configs/config_backward_reference.yaml @@ -0,0 +1,28 @@ +comp_x_1: + component_key: COMP_X + variant_key: default + config: + val_x: "some other value X" + single_dependency: + component_key: COMP_W + variant_key: default + config: + val_w: "some other value w" + +comp_y_1: + component_key: COMP_Y + variant_key: default + config: + val_y: "some other value y" + multi_dependency: + - component_key: COMP_W + variant_key: default + config: + val_w: "some other value w" + - component_key: COMP_V + variant_key: default + config: + val_v: "some other value v" + - instance_key: comp_x_1 + pass_type: BY_REFERENCE + diff --git a/tests/config/test_configs/config_forward_reference.yaml b/tests/config/test_configs/config_forward_reference.yaml new file mode 100644 index 00000000..d0ff73f2 --- /dev/null +++ b/tests/config/test_configs/config_forward_reference.yaml @@ -0,0 +1,27 @@ +comp_y_1: + component_key: COMP_Y + variant_key: default + config: + val_y: "some other value y" + multi_dependency: + - component_key: COMP_W + variant_key: default + config: + val_w: "some other value w" + - component_key: COMP_V + variant_key: default + config: + val_v: "some other value v" + - instance_key: comp_x_1 + pass_type: BY_REFERENCE + +comp_x_1: + component_key: COMP_X + variant_key: default + config: + val_x: "some other value X" + single_dependency: + component_key: COMP_W + variant_key: default + config: + val_w: "some other value w" \ No newline at end of file diff --git a/tests/config/test_configs/config_hierarchical_list_component.yaml b/tests/config/test_configs/config_hierarchical_list_component.yaml new file mode 100644 index 00000000..1298b590 --- /dev/null +++ b/tests/config/test_configs/config_hierarchical_list_component.yaml @@ -0,0 
+1,15 @@ + +comp_y_1: + component_key: COMP_Y + variant_key: default + config: + val_y: "some other value y" + multi_dependency: + - component_key: COMP_W + variant_key: default + config: + val_w: "some other value w" + - component_key: COMP_V + variant_key: default + config: + val_v: "some other value v" \ No newline at end of file diff --git a/tests/config/test_configs/config_non_existing_reference.yaml b/tests/config/test_configs/config_non_existing_reference.yaml new file mode 100644 index 00000000..91bfa504 --- /dev/null +++ b/tests/config/test_configs/config_non_existing_reference.yaml @@ -0,0 +1,17 @@ +comp_y_1: + component_key: COMP_Y + variant_key: default + config: + val_y: "some other value y" + multi_dependency: + - component_key: COMP_W + variant_key: default + config: + val_w: "some other value w" + - component_key: COMP_V + variant_key: default + config: + val_v: "some other value v" + - instance_key: comp_x_1 + pass_type: BY_REFERENCE + diff --git a/tests/config/test_configs/config_single_component.yaml b/tests/config/test_configs/config_single_component.yaml new file mode 100644 index 00000000..01bedf8a --- /dev/null +++ b/tests/config/test_configs/config_single_component.yaml @@ -0,0 +1,5 @@ +custom_comp_1: + component_key: COMP_V + variant_key: default + config: + val_v: "some value v" \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py index f94133ce..1ee049ca 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,6 +2,7 @@ import os import pickle from pathlib import Path +from typing import Dict from unittest.mock import MagicMock import pytest @@ -10,9 +11,8 @@ from torch.utils.data.sampler import BatchSampler, SequentialSampler from transformers import GPT2TokenizerFast -from modalities.__main__ import load_app_config_dict from modalities.checkpointing.checkpointing import CheckpointingIF -from modalities.config.config import AppConfig +from modalities.config.config import load_app_config_dict from modalities.dataloader.create_index import IndexGenerator from modalities.dataloader.dataloader import LLMDataLoader from modalities.dataloader.large_file_lines_reader import LargeFileLinesReader @@ -30,10 +30,11 @@ def dummy_packed_data_path(tmpdir) -> Path: data = b"" header_size_in_bytes = 8 - int_size_in_bytes = 4 + token_size_in_bytes = 4 tokens = list(range(20)) - data += (len(tokens) * int_size_in_bytes).to_bytes(header_size_in_bytes, byteorder="big") - data += b"".join([t.to_bytes(int_size_in_bytes, byteorder="big") for t in tokens]) + data += (len(tokens) * token_size_in_bytes).to_bytes(header_size_in_bytes, byteorder="big") + data += token_size_in_bytes.to_bytes(4, byteorder="big") + data += b"".join([t.to_bytes(token_size_in_bytes, byteorder="big") for t in tokens]) index = [(4, 24), (28, 40), (68, 12), (80, 4)] # [(index,len), ...] 
-> in 4 bytes #lengths: 6,10,3,1 data += pickle.dumps(index) dummy_packed_data_path = Path(tmpdir, "dummy.pbin") @@ -42,14 +43,13 @@ def dummy_packed_data_path(tmpdir) -> Path: @pytest.fixture -def dummy_config(monkeypatch) -> AppConfig: +def dummy_config(monkeypatch) -> Dict: monkeypatch.setenv("RANK", "0") monkeypatch.setenv("LOCAL_RANK", "0") monkeypatch.setenv("WORLD_SIZE", "1") dummy_config_path = _ROOT_DIR / Path("config_files/config_lorem_ipsum.yaml") config_dict = load_app_config_dict(dummy_config_path) - app_config = AppConfig.model_validate(config_dict) - return app_config + return config_dict, dummy_config_path @dataclasses.dataclass @@ -77,7 +77,7 @@ def indexed_dummy_data_path(dummy_data_path) -> DataPathCollection: @pytest.fixture def gpt2_tokenizer() -> GPT2TokenizerFast: - default_gpt2_tokenizer_path = Path(__file__).parents[1] / Path("data", "tokenizer", "tokenizer.json") + default_gpt2_tokenizer_path = Path(__file__).parents[1] / Path("data", "tokenizer", "tokenizer_gpt2.json") assert default_gpt2_tokenizer_path.is_file() return GPT2TokenizerFast(tokenizer_file=str(default_gpt2_tokenizer_path)) @@ -123,7 +123,7 @@ def trainer(progress_publisher_mock): local_rank=int(os.getenv("LOCAL_RANK")), batch_progress_publisher=progress_publisher_mock, evaluation_result_publisher=progress_publisher_mock, - gradient_acc_step=1, + gradient_acc_steps=1, ) diff --git a/tests/dataloader/test_dataloader.py b/tests/dataloader/test_dataloader.py index 399226d8..3676cb4f 100644 --- a/tests/dataloader/test_dataloader.py +++ b/tests/dataloader/test_dataloader.py @@ -1,11 +1,15 @@ +from typing import Dict + import torch +from pydantic import BaseModel from torch.utils.data import BatchSampler, SequentialSampler -from modalities.config.config import AppConfig +from modalities.config.component_factory import ComponentFactory +from modalities.config.config import PydanticLLMDataLoaderIFType from modalities.dataloader.dataloader import LLMDataLoader -from modalities.dataloader.dataloader_factory import DataloaderFactory from modalities.dataloader.samplers import ResumableBatchSampler -from modalities.resolver_register import ResolverRegister +from modalities.registry.components import COMPONENTS +from modalities.registry.registry import Registry def test_resumable_dataloader() -> LLMDataLoader: @@ -21,18 +25,25 @@ def test_resumable_dataloader() -> LLMDataLoader: assert (flat_samples == original_samples).all() -def test_dataloader_from_config(dummy_config: AppConfig): - resolvers = ResolverRegister(config=dummy_config) +def test_dataloader_from_config(dummy_config: Dict): start_index = 2 - dataloader_1: LLMDataLoader = DataloaderFactory.get_dataloader( - resolvers=resolvers, config=dummy_config.data.train_dataloader, skip_num_batches=start_index - ) - dataset = dataloader_1.dataset + config_dict, _ = dummy_config + config_dict["train_dataloader"]["config"]["skip_num_batches"] = start_index - distributed_sampler = dataloader_1.batch_sampler.underlying_batch_sampler.sampler - batch_sampler = BatchSampler( - sampler=distributed_sampler, batch_size=dataloader_1.sampler_batch_size, drop_last=False + class DataloaderTestModel(BaseModel): + train_dataloader: PydanticLLMDataLoaderIFType + + registry = Registry(COMPONENTS) + component_factory = ComponentFactory(registry=registry) + components: DataloaderTestModel = component_factory.build_components( + config_dict=config_dict, components_model_type=DataloaderTestModel ) + + dataloader_1: LLMDataLoader = components.train_dataloader + dataset = 
dataloader_1.dataset + resumable_batch_sampler: ResumableBatchSampler = dataloader_1.batch_sampler + distributed_sampler = resumable_batch_sampler.underlying_batch_sampler.sampler + batch_sampler = BatchSampler(sampler=distributed_sampler, batch_size=dataloader_1.batch_size, drop_last=False) dataloader_2 = LLMDataLoader( dataloader_tag="train", dataset=dataset, batch_sampler=batch_sampler, collate_fn=dataloader_1.collate_fn ) @@ -40,7 +51,7 @@ def test_dataloader_from_config(dummy_config: AppConfig): samples_1 = [batch for _, batch in zip(range(10), dataloader_1)] samples_2 = [batch for _, batch in zip(range(10), dataloader_2)] - assert dataloader_1.sampler_batch_size * len(dataloader_2) == len(dataset) + assert dataloader_1.batch_size * len(dataloader_2) == len(dataset) assert len(dataloader_1) + start_index == len(dataloader_2) diff --git a/tests/dataloader/test_large_file_lines_reader.py b/tests/dataloader/test_large_file_lines_reader.py index e91287ed..a2dc546e 100644 --- a/tests/dataloader/test_large_file_lines_reader.py +++ b/tests/dataloader/test_large_file_lines_reader.py @@ -1,6 +1,7 @@ import json import pickle import tempfile +import warnings from pathlib import Path import pytest @@ -37,9 +38,11 @@ def generate_data_index_file(data_path: Path, **kwargs): dummy_dst_path.unlink(missing_ok=True) indexer.create_index(dummy_dst_path) - with pytest.raises(ValueError): - generate_data_index_file(plain_text_data_path) - generate_data_index_file(plain_text_data_path, drop_faulty_entries=True) + with warnings.catch_warnings(): + warnings.filterwarnings("ignore") + with pytest.raises(ValueError): + generate_data_index_file(plain_text_data_path) + generate_data_index_file(plain_text_data_path, drop_faulty_entries=True) generate_data_index_file(jsonl_data_path) index = pickle.loads(dummy_dst_path.read_bytes()) diff --git a/tests/dataloader/test_packed_dataset.py b/tests/dataloader/test_packed_dataset.py index 64df0e9c..996565da 100644 --- a/tests/dataloader/test_packed_dataset.py +++ b/tests/dataloader/test_packed_dataset.py @@ -1,6 +1,9 @@ +from pathlib import Path + +import numpy as np import pytest -from modalities.dataloader.create_packed_data import PackedDataGenerator +from modalities.dataloader.create_packed_data import EmbeddedStreamData, PackedDataGenerator, join_embedded_stream_data from modalities.dataloader.dataset import PackedMemMapDatasetContinuous, PackedMemMapDatasetMegatron @@ -35,11 +38,10 @@ def test_packed_continuous_dataset_missing_file(dummy_packed_data_path): PackedMemMapDatasetContinuous(dummy_packed_data_path, block_size=10, sample_key="input_ids") -@pytest.mark.parametrize("max_num_of_tokens, expected_index_size", [(None, 12), (10, 1)]) -def test_create_packed_dataset(indexed_dummy_data_path, gpt2_tokenizer, max_num_of_tokens, expected_index_size): +def test_create_packed_dataset(indexed_dummy_data_path, gpt2_tokenizer): block_size = 5 packed_generator = PackedDataGenerator( - src_path=indexed_dummy_data_path.raw_data_path, tokenizer=gpt2_tokenizer, max_number_of_tokens=max_num_of_tokens + src_path=indexed_dummy_data_path.raw_data_path, tokenizer=gpt2_tokenizer, number_of_processes=2 ) default_packed_dataset_path = packed_generator._default_destination_path() assert not default_packed_dataset_path.is_file() @@ -51,10 +53,35 @@ def test_create_packed_dataset(indexed_dummy_data_path, gpt2_tokenizer, max_num_ start_of_jsonl_content = "0 Lorem ipsum dolor sit amet, consetetur sadipscing elitr, sed diam nonumy eirmod tempor" tokenized_start_of_jsonl_content = 
gpt2_tokenizer(start_of_jsonl_content)["input_ids"] packed_dataset_iterator = iter(packed_dataset) - assert tokenized_start_of_jsonl_content[:block_size] == next(packed_dataset_iterator)["input_ids"] - assert tokenized_start_of_jsonl_content[block_size : 2 * block_size] == next(packed_dataset_iterator)["input_ids"] - assert len(packed_dataset.index_base) == expected_index_size + np.testing.assert_equal(tokenized_start_of_jsonl_content[:block_size], next(packed_dataset_iterator)["input_ids"]) + np.testing.assert_equal( + tokenized_start_of_jsonl_content[block_size : 2 * block_size], next(packed_dataset_iterator)["input_ids"] + ) + assert len(packed_dataset._embedded_stream_data.index_base) == 12 # check validity of index section in packed dataset - for idx, (offset, entry_length) in enumerate(packed_dataset.index_base[:-1]): - assert offset + entry_length == packed_dataset.index_base[idx + 1][0] + for idx, (offset, entry_length) in enumerate(packed_dataset._embedded_stream_data.index_base[:-1]): + assert offset + entry_length == packed_dataset._embedded_stream_data.index_base[idx + 1][0] + + +def test_join_packed_datasets(dummy_packed_data_path, tmpdir): + packed_data_clones = [Path(tmpdir, f"clone{i}.pbin") for i in range(3)] + for clone in packed_data_clones: + clone.write_bytes(dummy_packed_data_path.read_bytes()) + + joined_target_file = Path(tmpdir, "joined.pbin") + + stream_data = list(map(EmbeddedStreamData, packed_data_clones)) + join_embedded_stream_data(stream_data, joined_target_file) + + loaded_joint_data = EmbeddedStreamData(joined_target_file) + assert loaded_joint_data + assert loaded_joint_data.data_len == sum(d.data_len for d in stream_data) + + loaded_dataset = PackedMemMapDatasetContinuous(joined_target_file, block_size=2, sample_key="whatever") + original_datasets = [ + PackedMemMapDatasetContinuous(p, block_size=2, sample_key="whatever") for p in packed_data_clones + ] + assert [v for batch in loaded_dataset for v in batch["whatever"]] == [ + v for ds in original_datasets for batch in ds for v in batch["whatever"] + ] diff --git a/tests/test_evaluation.py b/tests/test_evaluation.py deleted file mode 100644 index 2cffe687..00000000 --- a/tests/test_evaluation.py +++ /dev/null @@ -1,52 +0,0 @@ -import torch -from transformers import AutoConfig, AutoModelForCausalLM - -from modalities.config.config import PretrainedGPTConfig -from modalities.models.gpt2.gpt2_model import ( - ActivationType, - AttentionConfig, - AttentionType, - GPT2Config, - WeightInitailizationConfig, -) -from modalities.models.gpt2.pretrained_gpt_model import PretrainedGPTModel - - -def test_pretrained_gpt_model(tmp_path): - # setup config and model - attention_config = AttentionConfig(attention_type=AttentionType("default_attention"), scaling_factor=3) - config = GPT2Config( - block_size=12, - vocab_size=128, - n_layer=2, - n_head=2, - n_embd=128, - ffn_hidden=128, - dropout=0.01, - bias=True, - attention=attention_config, - activation=ActivationType.GELU, - epsilon=1e-5, - sample_key="input_ids", - prediction_key="logits", - weight_init=WeightInitailizationConfig(mean=0, std=0.02), - ) - pretrained_config = PretrainedGPTConfig(config=config) - - model = PretrainedGPTModel(config=pretrained_config) - model.save_pretrained(tmp_path) - model = model.eval() - - # register config and model - AutoConfig.register("modalities_gpt2", PretrainedGPTConfig) - AutoModelForCausalLM.register(PretrainedGPTConfig, PretrainedGPTModel) - - # load saved model - loaded_model = AutoModelForCausalLM.from_pretrained(tmp_path) 
- loaded_model = loaded_model.eval() - - # check that model before and after loading return the same output - test_tensor = torch.randint(10, size=(5, 10)) - output_before_loading = model.forward(test_tensor) - output_after_loading = loaded_model.forward(test_tensor) - assert (output_after_loading == output_before_loading).all() diff --git a/tests/test_gym.py b/tests/test_gym.py index f03d9b92..d650f736 100644 --- a/tests/test_gym.py +++ b/tests/test_gym.py @@ -1,10 +1,9 @@ -from unittest.mock import call, patch +from unittest.mock import call import torch from modalities.batch import DatasetBatch from modalities.gym import Gym -from modalities.running_env.fsdp.reducer import Reducer def test_run_cpu_only( @@ -37,15 +36,13 @@ def test_run_cpu_only( llm_data_loader_mock.__len__ = lambda _: num_batches gym = Gym(trainer=trainer, evaluator=evaluator_mock, loss_fun=loss_mock, num_ranks=num_ranks) - with patch.object(Reducer, "reduce", return_value=None) as reduce_mock: - gym.run( - model=nn_model_mock, - optimizer=optimizer_mock, - callback_interval_in_batches=int(num_batches), - train_data_loader=llm_data_loader_mock, - evaluation_data_loaders=[], - checkpointing=checkpointing_mock, - ) - nn_model_mock.forward.assert_has_calls([call(b.samples) for b in batches]) - optimizer_mock.step.assert_called() - reduce_mock.assert_called() + gym.run( + model=nn_model_mock, + optimizer=optimizer_mock, + callback_interval_in_batches=int(num_batches), + train_data_loader=llm_data_loader_mock, + evaluation_data_loaders=[], + checkpointing=checkpointing_mock, + ) + nn_model_mock.forward.assert_has_calls([call(b.samples) for b in batches]) + optimizer_mock.step.assert_called() diff --git a/tests/test_loss_functions.py b/tests/test_loss_functions.py new file mode 100644 index 00000000..8825f15c --- /dev/null +++ b/tests/test_loss_functions.py @@ -0,0 +1,38 @@ +import pytest +import torch + +from modalities.batch import InferenceResultBatch +from modalities.loss_functions import NCELoss, nce_loss + + +@pytest.fixture +def dummy_result_batch() -> InferenceResultBatch: + predictions = {"embedding": torch.rand(1024, 512)} + targets = {"target": torch.zeros(1024, 512)} + batch_dim = 1024 + result_batch = InferenceResultBatch(targets, predictions, batch_dim) + return result_batch + + +# calculating asymmetric NCELoss between a batch of embeddings and itself --> zero +@pytest.mark.parametrize("key", ["embedding"]) +def test_asymm_NCELoss_is_zero(dummy_result_batch, key): + loss_func = NCELoss(prediction_key1=key, prediction_key2=key) + assert loss_func(dummy_result_batch) <= 10e-6 + + +# calculating nce_loss for two randomly generated batch of embeddings (manually calculated) +@pytest.mark.parametrize( + "embedding1,embedding2", + [ + ( + torch.Tensor([[0.38, 0.18], [0.36, 0.66], [0.72, 0.09]]), + torch.Tensor([[0.48, 0.01], [0.54, 0.28], [0.08, 0.34]]), + ) + ], +) +def test_nce_loss_correctness(embedding1, embedding2): + unidirectional_loss = nce_loss(embedding1, embedding2, device="cpu", is_asymmetric=True, temperature=1.0) + bidirectional_loss = nce_loss(embedding1, embedding2, device="cpu", is_asymmetric=False, temperature=1.0) + assert unidirectional_loss == pytest.approx(1.1300, 0.0001) + assert bidirectional_loss == pytest.approx(2.2577, 0.0001) diff --git a/tests/test_main.py b/tests/test_main.py index 92ab8f9f..b81d9d7c 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -4,20 +4,12 @@ from modalities.__main__ import Main -def no_gpu_available() -> bool: - return not torch.cuda.is_available() - - 
-@pytest.mark.skipif( - no_gpu_available(), reason="This e2e test verifies a GPU-Setup and uses components, which do not support CPU-only." -) +@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="This e2e test requires 1 GPU.") def test_e2e_training_run_wout_ckpt(monkeypatch, indexed_dummy_data_path, dummy_config): # patch in env variables monkeypatch.setenv("MASTER_ADDR", "localhost") monkeypatch.setenv("MASTER_PORT", "9948") - - dummy_config.data.train_dataloader.config.dataset.config.raw_data_path = indexed_dummy_data_path.raw_data_path - for val_dataloader_config in dummy_config.data.eval_dataloaders: - val_dataloader_config.config.dataset.config.raw_data_path = indexed_dummy_data_path.raw_data_path - main = Main(dummy_config) + config_dict, config_path = dummy_config + config_dict["train_dataset"]["config"]["raw_data_path"] = indexed_dummy_data_path.raw_data_path + main = Main(config_dict, config_path) main.run() From 10110c8456057e977ab0e492383a8413d6868b97 Mon Sep 17 00:00:00 2001 From: Luzian Hahn Date: Mon, 11 Mar 2024 10:16:20 +0100 Subject: [PATCH 3/9] chore: align configs with new GQA keys --- config_files/config.yaml | 4 ++-- config_files/config_example_mem_map_dataset.yaml | 3 ++- config_files/config_example_openGPTx_dataset.yaml | 3 ++- config_files/config_lorem_ipsum.yaml | 3 ++- examples/getting_started/example_config.yaml | 3 ++- examples/library_usage/config_lorem_ipsum.yaml | 3 ++- tests/checkpointing/gpt2_config.yaml | 3 ++- 7 files changed, 14 insertions(+), 8 deletions(-) diff --git a/config_files/config.yaml b/config_files/config.yaml index 6925155b..cf9cc961 100644 --- a/config_files/config.yaml +++ b/config_files/config.yaml @@ -142,8 +142,8 @@ model: prediction_key: "logits" block_size: ${data.sequence_len} vocab_size: 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency - n_layer: 12 - n_head: 12 + n_layer_q: 12 + n_head_kv: 12 ffn_hidden: 2048 n_embd: 768 dropout: 0.0 diff --git a/config_files/config_example_mem_map_dataset.yaml b/config_files/config_example_mem_map_dataset.yaml index 62498856..c5671efe 100644 --- a/config_files/config_example_mem_map_dataset.yaml +++ b/config_files/config_example_mem_map_dataset.yaml @@ -138,7 +138,8 @@ model: block_size: ${settings.training.sequence_length} vocab_size: 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency n_layer: 12 - n_head: 12 + n_head_q: 12 + n_head_kv: 12 ffn_hidden: 2048 n_embd: 768 dropout: 0.0 diff --git a/config_files/config_example_openGPTx_dataset.yaml b/config_files/config_example_openGPTx_dataset.yaml index b5f3eef6..f817e4c1 100644 --- a/config_files/config_example_openGPTx_dataset.yaml +++ b/config_files/config_example_openGPTx_dataset.yaml @@ -145,7 +145,8 @@ model: block_size: ${data.sequence_len} vocab_size: 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency n_layer: 12 - n_head: 12 + n_head_q: 12 + n_head_kv: 12 ffn_hidden: 2048 n_embd: 768 dropout: 0.0 diff --git a/config_files/config_lorem_ipsum.yaml b/config_files/config_lorem_ipsum.yaml index c9f01291..07d71220 100644 --- a/config_files/config_lorem_ipsum.yaml +++ b/config_files/config_lorem_ipsum.yaml @@ -191,7 +191,8 @@ model: block_size: 256 # TODO reference this (same as sequence length) vocab_size: 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency n_layer: 2 - n_head: 4 + n_head_q: 8 + n_head_kv: 2 ffn_hidden: 128 n_embd: 128 dropout: 0.0 diff --git 
a/examples/getting_started/example_config.yaml b/examples/getting_started/example_config.yaml index b4c788f6..104815b4 100644 --- a/examples/getting_started/example_config.yaml +++ b/examples/getting_started/example_config.yaml @@ -114,7 +114,8 @@ model: block_size: ${data.sequence_len} vocab_size: 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency n_layer: 12 - n_head: 12 + n_head_q: 12 + n_head_kv: 12 ffn_hidden: 2048 n_embd: 768 dropout: 0.0 diff --git a/examples/library_usage/config_lorem_ipsum.yaml b/examples/library_usage/config_lorem_ipsum.yaml index f41a2507..61f34cb7 100644 --- a/examples/library_usage/config_lorem_ipsum.yaml +++ b/examples/library_usage/config_lorem_ipsum.yaml @@ -193,7 +193,8 @@ model: block_size: 256 # TODO reference this (same as sequence length) vocab_size: 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency n_layer: 2 - n_head: 4 + n_head_q: 4 + n_head_kv: 4 ffn_hidden: 128 n_embd: 128 dropout: 0.0 diff --git a/tests/checkpointing/gpt2_config.yaml b/tests/checkpointing/gpt2_config.yaml index 3865149e..1b2d5c80 100644 --- a/tests/checkpointing/gpt2_config.yaml +++ b/tests/checkpointing/gpt2_config.yaml @@ -7,7 +7,8 @@ model: block_size: 256 # TODO reference this (same as sequence length) vocab_size: 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency n_layer: 2 - n_head: 4 + n_head_q: 4 + n_head_kv: 4 ffn_hidden: 128 n_embd: 128 dropout: 0.0 From 683234325705971b4e845883c95e6189ae70e0b7 Mon Sep 17 00:00:00 2001 From: Luzian Hahn Date: Mon, 11 Mar 2024 10:30:05 +0100 Subject: [PATCH 4/9] docs: add potential removal marker for "scaling_factor" --- src/modalities/models/gpt2/gpt2_model.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/modalities/models/gpt2/gpt2_model.py b/src/modalities/models/gpt2/gpt2_model.py index e326b690..fbc4f66f 100644 --- a/src/modalities/models/gpt2/gpt2_model.py +++ b/src/modalities/models/gpt2/gpt2_model.py @@ -26,6 +26,7 @@ class ActivationType(str, Enum): class AttentionConfig(BaseModel): attention_type: AttentionType + # TODO: Is this parameter really necessary? Or can it be always 3? 
scaling_factor: Annotated[int, Field(strict=True, ge=1)] From d20005d4a0a22c147c924fe4e66894e4ff673f88 Mon Sep 17 00:00:00 2001 From: Luzian Hahn Date: Mon, 11 Mar 2024 11:12:53 +0100 Subject: [PATCH 5/9] test: add attention forward pass test for GQA --- tests/models/test_attention.py | 40 ++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) create mode 100644 tests/models/test_attention.py diff --git a/tests/models/test_attention.py b/tests/models/test_attention.py new file mode 100644 index 00000000..f53acdd3 --- /dev/null +++ b/tests/models/test_attention.py @@ -0,0 +1,40 @@ +import pytest +import torch + +from modalities.models.gpt2.gpt2_model import AttentionConfig, AttentionType, CausalSelfAttention + + +@pytest.mark.parametrize( + "n_head_q, n_head_kv, n_embd, att_type, successful", + [ + (4, 4, 32, AttentionType.DEFAULT_ATTENTION, True), + (8, 2, 32, AttentionType.DEFAULT_ATTENTION, True), + (9, 8, 32, AttentionType.DEFAULT_ATTENTION, False), + (8, 3, 32, AttentionType.DEFAULT_ATTENTION, False), + ], +) +def test_grouped_query_attention_forward(n_head_q, n_head_kv, n_embd, att_type, successful): + batch_size = 2 + block_size = 10 + embedding_shape = (batch_size, block_size, n_embd) + embedded_input_seq = torch.rand(size=embedding_shape, dtype=torch.float32) + + def attention_forward_pass(att_type, block_size, embedded_input_seq, n_embd, n_head_kv, n_head_q): + attention_layer = CausalSelfAttention( + n_head_q=n_head_q, + n_head_kv=n_head_kv, + n_embd=n_embd, + attention=AttentionConfig(attention_type=att_type, scaling_factor=3), + bias=False, + dropout=False, + block_size=block_size, + ) + output_tensor: torch.Tensor = attention_layer(embedded_input_seq) + return output_tensor + + if not successful: + with pytest.raises(Exception): + attention_forward_pass(att_type, block_size, embedded_input_seq, n_embd, n_head_kv, n_head_q) + else: + output_tensor = attention_forward_pass(att_type, block_size, embedded_input_seq, n_embd, n_head_kv, n_head_q) + assert output_tensor.size() == embedding_shape From 7dffa623a955955f84eed44a5f6e8e1d97eafa4e Mon Sep 17 00:00:00 2001 From: Luzian Hahn Date: Mon, 11 Mar 2024 11:13:52 +0100 Subject: [PATCH 6/9] fix: add verbose check for divisibility of K,V,Q matrix shapes --- src/modalities/models/gpt2/gpt2_model.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/modalities/models/gpt2/gpt2_model.py b/src/modalities/models/gpt2/gpt2_model.py index fbc4f66f..d64f2952 100644 --- a/src/modalities/models/gpt2/gpt2_model.py +++ b/src/modalities/models/gpt2/gpt2_model.py @@ -97,7 +97,14 @@ def __init__( block_size: int, ): super().__init__() - assert n_embd % n_head_q == 0 + assert n_embd % n_head_q == 0, ( + "Embeddings get passed to `n_head_q` different heads " + "and their dimension needs to be divisible by `n_head_q`." + ) + assert n_head_q % n_head_kv == 0, ( + "It is necessary to have `n_head_q` divisible by `n_head_kv`." 
+ ' For more details, read about "Grouped Query Attention"' + ) # key, query, value projections for all heads, but in a batch self.c_attn = nn.Linear( in_features=n_embd, From 7eecb346f4841398cb9bb392f03d81d98e2adea1 Mon Sep 17 00:00:00 2001 From: Felix Stollenwerk Date: Mon, 11 Mar 2024 12:03:29 +0100 Subject: [PATCH 7/9] refactor: remove AttentionConfig --- config_files/config.yaml | 4 +-- .../config_example_mem_map_dataset.yaml | 4 +-- .../config_example_openGPTx_dataset.yaml | 4 +-- config_files/config_lorem_ipsum.yaml | 4 +-- examples/getting_started/example_config.yaml | 4 +-- .../library_usage/config_lorem_ipsum.yaml | 4 +-- src/modalities/models/gpt2/gpt2_model.py | 27 +++++++++---------- tests/checkpointing/gpt2_config.yaml | 4 +-- tests/models/test_attention.py | 16 ++++++----- 9 files changed, 29 insertions(+), 42 deletions(-) diff --git a/config_files/config.yaml b/config_files/config.yaml index cf9cc961..ad57ddef 100644 --- a/config_files/config.yaml +++ b/config_files/config.yaml @@ -148,9 +148,7 @@ model: n_embd: 768 dropout: 0.0 bias: true # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster - attention: - attention_type: pytorch_flash_attention - scaling_factor: 3 + attention_type: pytorch_flash_attention activation: gelu epsilon: 1e-5 weight_init: diff --git a/config_files/config_example_mem_map_dataset.yaml b/config_files/config_example_mem_map_dataset.yaml index c5671efe..2c536d67 100644 --- a/config_files/config_example_mem_map_dataset.yaml +++ b/config_files/config_example_mem_map_dataset.yaml @@ -144,9 +144,7 @@ model: n_embd: 768 dropout: 0.0 bias: true # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster - attention: - attention_type: pytorch_flash_attention - scaling_factor: 3 + attention_type: pytorch_flash_attention activation: gelu epsilon: 1e-5 weight_init: diff --git a/config_files/config_example_openGPTx_dataset.yaml b/config_files/config_example_openGPTx_dataset.yaml index f817e4c1..8f3c6e35 100644 --- a/config_files/config_example_openGPTx_dataset.yaml +++ b/config_files/config_example_openGPTx_dataset.yaml @@ -151,9 +151,7 @@ model: n_embd: 768 dropout: 0.0 bias: true # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster - attention: - attention_type: pytorch_flash_attention - scaling_factor: 3 + attention_type: pytorch_flash_attention activation: fused_swiglu epsilon: 1e-5 weight_init: diff --git a/config_files/config_lorem_ipsum.yaml b/config_files/config_lorem_ipsum.yaml index 07d71220..7e8ffd51 100644 --- a/config_files/config_lorem_ipsum.yaml +++ b/config_files/config_lorem_ipsum.yaml @@ -197,9 +197,7 @@ model: n_embd: 128 dropout: 0.0 bias: true # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster - attention: - attention_type: default_attention # pytorch_flash_attention - scaling_factor: 3 + attention_type: default_attention # pytorch_flash_attention activation: gelu epsilon: 1e-5 weight_init: diff --git a/examples/getting_started/example_config.yaml b/examples/getting_started/example_config.yaml index 104815b4..3505b392 100644 --- a/examples/getting_started/example_config.yaml +++ b/examples/getting_started/example_config.yaml @@ -120,9 +120,7 @@ model: n_embd: 768 dropout: 0.0 bias: true # True: bias in Linears and LayerNorms, like GPT-2. 
False: a bit better and faster - attention: - attention_type: pytorch_flash_attention - scaling_factor: 3 + attention_type: pytorch_flash_attention activation: gelu epsilon: 1e-5 weight_init: diff --git a/examples/library_usage/config_lorem_ipsum.yaml b/examples/library_usage/config_lorem_ipsum.yaml index 61f34cb7..02eeca79 100644 --- a/examples/library_usage/config_lorem_ipsum.yaml +++ b/examples/library_usage/config_lorem_ipsum.yaml @@ -199,9 +199,7 @@ model: n_embd: 128 dropout: 0.0 bias: true # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster - attention: - attention_type: default_attention # pytorch_flash_attention - scaling_factor: 3 + attention_type: default_attention # pytorch_flash_attention activation: gelu epsilon: 1e-5 weight_init: diff --git a/src/modalities/models/gpt2/gpt2_model.py b/src/modalities/models/gpt2/gpt2_model.py index d64f2952..6ef55330 100644 --- a/src/modalities/models/gpt2/gpt2_model.py +++ b/src/modalities/models/gpt2/gpt2_model.py @@ -24,12 +24,6 @@ class ActivationType(str, Enum): FUSED_SWIGLU = "fused_swiglu" -class AttentionConfig(BaseModel): - attention_type: AttentionType - # TODO: Is this parameter really necessary? Or can it be always 3? - scaling_factor: Annotated[int, Field(strict=True, ge=1)] - - class WeightInitailizationConfig(BaseModel): mean: Annotated[float, Field(strict=True, ge=0.0)] std: Annotated[float, Field(strict=True, ge=0.0)] @@ -50,7 +44,7 @@ class GPT2LLMConfig(BaseModel): dropout: Annotated[float, Field(strict=True, ge=0.0)] bias: bool # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster - attention: AttentionConfig + attention_type: AttentionType activation: ActivationType epsilon: Annotated[float, Field(strict=True, ge=0.0)] weight_init: WeightInitailizationConfig @@ -91,7 +85,7 @@ def __init__( n_head_q: int, n_head_kv: int, n_embd: int, - attention: AttentionConfig, + attention_type: AttentionType, bias: bool, dropout: float, block_size: int, @@ -105,10 +99,15 @@ def __init__( "It is necessary to have `n_head_q` divisible by `n_head_kv`." 
' For more details, read about "Grouped Query Attention"' ) + + _joint_projection_factor = ( + 3 # the projection matrices for query, key & values are concatenated to a single matrix + ) + # key, query, value projections for all heads, but in a batch self.c_attn = nn.Linear( in_features=n_embd, - out_features=attention.scaling_factor * n_embd, + out_features=_joint_projection_factor * n_embd, bias=bias, ) @@ -127,7 +126,7 @@ def __init__( self.n_embd = n_embd self.dropout = dropout - self.flash = attention.attention_type == AttentionType.PYTORCH_FLASH_ATTENTION + self.flash = attention_type == AttentionType.PYTORCH_FLASH_ATTENTION if not self.flash: # causal mask to ensure that attention is only applied to the left in the input sequence @@ -203,7 +202,7 @@ def __init__( activation: ActivationType, n_head_q: int, n_head_kv: int, - attention: AttentionConfig, + attention_type: AttentionType, dropout: float, block_size: int, ffn_hidden: int, @@ -214,7 +213,7 @@ def __init__( n_head_q=n_head_q, n_head_kv=n_head_kv, n_embd=n_embd, - attention=attention, + attention_type=attention_type, bias=bias, dropout=dropout, block_size=block_size, @@ -249,7 +248,7 @@ def __init__( ffn_hidden: int, dropout: float, bias: bool, - attention: AttentionConfig, + attention_type: AttentionType, activation: ActivationType, epsilon: float, weight_init: WeightInitailizationConfig, @@ -276,7 +275,7 @@ def __init__( activation=activation, n_head_q=n_head_q, n_head_kv=n_head_kv, - attention=attention, + attention_type=attention_type, dropout=dropout, block_size=block_size, ffn_hidden=ffn_hidden, diff --git a/tests/checkpointing/gpt2_config.yaml b/tests/checkpointing/gpt2_config.yaml index 1b2d5c80..8a4b4946 100644 --- a/tests/checkpointing/gpt2_config.yaml +++ b/tests/checkpointing/gpt2_config.yaml @@ -13,9 +13,7 @@ model: n_embd: 128 dropout: 0.0 bias: true # True: bias in Linears and LayerNorms, like GPT-2. 
False: a bit better and faster - attention: - attention_type: default_attention # pytorch_flash_attention - scaling_factor: 3 + attention_type: default_attention # pytorch_flash_attention activation: gelu epsilon: 1e-5 weight_init: diff --git a/tests/models/test_attention.py b/tests/models/test_attention.py index f53acdd3..5a317baa 100644 --- a/tests/models/test_attention.py +++ b/tests/models/test_attention.py @@ -1,11 +1,11 @@ import pytest import torch -from modalities.models.gpt2.gpt2_model import AttentionConfig, AttentionType, CausalSelfAttention +from modalities.models.gpt2.gpt2_model import AttentionType, CausalSelfAttention @pytest.mark.parametrize( - "n_head_q, n_head_kv, n_embd, att_type, successful", + "n_head_q, n_head_kv, n_embd, attention_type, successful", [ (4, 4, 32, AttentionType.DEFAULT_ATTENTION, True), (8, 2, 32, AttentionType.DEFAULT_ATTENTION, True), @@ -13,18 +13,18 @@ (8, 3, 32, AttentionType.DEFAULT_ATTENTION, False), ], ) -def test_grouped_query_attention_forward(n_head_q, n_head_kv, n_embd, att_type, successful): +def test_grouped_query_attention_forward(n_head_q, n_head_kv, n_embd, attention_type, successful): batch_size = 2 block_size = 10 embedding_shape = (batch_size, block_size, n_embd) embedded_input_seq = torch.rand(size=embedding_shape, dtype=torch.float32) - def attention_forward_pass(att_type, block_size, embedded_input_seq, n_embd, n_head_kv, n_head_q): + def attention_forward_pass(attention_type, block_size, embedded_input_seq, n_embd, n_head_kv, n_head_q): attention_layer = CausalSelfAttention( n_head_q=n_head_q, n_head_kv=n_head_kv, n_embd=n_embd, - attention=AttentionConfig(attention_type=att_type, scaling_factor=3), + attention_type=attention_type, bias=False, dropout=False, block_size=block_size, @@ -34,7 +34,9 @@ def attention_forward_pass(att_type, block_size, embedded_input_seq, n_embd, n_h if not successful: with pytest.raises(Exception): - attention_forward_pass(att_type, block_size, embedded_input_seq, n_embd, n_head_kv, n_head_q) + attention_forward_pass(attention_type, block_size, embedded_input_seq, n_embd, n_head_kv, n_head_q) else: - output_tensor = attention_forward_pass(att_type, block_size, embedded_input_seq, n_embd, n_head_kv, n_head_q) + output_tensor = attention_forward_pass( + attention_type, block_size, embedded_input_seq, n_embd, n_head_kv, n_head_q + ) assert output_tensor.size() == embedding_shape From 66a788b01f993e336a3d63d58ddcd16d7ccc05a6 Mon Sep 17 00:00:00 2001 From: Luzian Hahn Date: Mon, 11 Mar 2024 14:04:16 +0100 Subject: [PATCH 8/9] debug: expanded KVs for GQA implementation --- src/modalities/models/gpt2/gpt2_model.py | 26 +++++++++++++++++++----- tests/models/test_attention.py | 7 ++++--- 2 files changed, 25 insertions(+), 8 deletions(-) diff --git a/src/modalities/models/gpt2/gpt2_model.py b/src/modalities/models/gpt2/gpt2_model.py index 6ef55330..0deba8ba 100644 --- a/src/modalities/models/gpt2/gpt2_model.py +++ b/src/modalities/models/gpt2/gpt2_model.py @@ -140,11 +140,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # calculate query, key, values for all heads in batch and move head forward to be the batch dim q, k, v = self.c_attn(x).split(self.n_embd, dim=2) - k = k.view(B, T, self.n_head_kv, C // self.n_head_kv).transpose(1, 2) # (B, nh, T, hs) - q = q.view(B, T, self.n_head_q, C // self.n_head_q).transpose(1, 2) # (B, nh, T, hs) - v = v.view(B, T, self.n_head_kv, C // self.n_head_kv).transpose(1, 2) # (B, nh, T, hs) + q = q.view(B, T, self.n_head_q, C // self.n_head_q).transpose(1, 2) # (B, 
nh_q, T, hs) + k = k.view(B, T, self.n_head_kv, C // self.n_head_kv).transpose(1, 2) # (B, nh_kv, T, hs) + v = v.view(B, T, self.n_head_kv, C // self.n_head_kv).transpose(1, 2) # (B, nh_kv, T, hs) - # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) + # repeat k/v heads if n_kv_heads < n_heads + k = repeat_kv(k, self.n_head_q // self.n_head_kv) # (B, nh_q, T, hs) + v = repeat_kv(v, self.n_head_q // self.n_head_kv) # (B, nh_q, T, hs) + + # causal self-attention; Self-attend: (B, nh_q, T, hs) x (B, nh_q, hs, T) -> (B, nh_q, T, T) if self.flash: # efficient attention using Flash Attention CUDA kernels y = torch.nn.functional.scaled_dot_product_attention( @@ -161,7 +165,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf")) att = F.softmax(att, dim=-1) att = self.attn_dropout(att) - y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) + y = att @ v # (B, nh_q, T, T) x (B, nh_q, T, hs) -> (B, nh_q, T, hs) y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side # output projection @@ -327,3 +331,15 @@ def forward_impl(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tenso def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: return self.forward_impl(inputs) + + +def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + Source code adopted from + https://github.com/facebookresearch/llama/blob/9a001c7a0987afd7b8de94e538916eff8950a73a/llama/model.py#L164 + Adapted ordered dimensions and namings: bs=B, n_kv_heads=nh_kv, slen=T, head_dim=hs + """ + B, nh_kv, T, hs = x.shape + if n_rep == 1: + return x + return x[:, :, None, :, :].expand(B, nh_kv, n_rep, T, hs).reshape(B, nh_kv * n_rep, T, hs) diff --git a/tests/models/test_attention.py b/tests/models/test_attention.py index 5a317baa..c5657f8a 100644 --- a/tests/models/test_attention.py +++ b/tests/models/test_attention.py @@ -7,10 +7,11 @@ @pytest.mark.parametrize( "n_head_q, n_head_kv, n_embd, attention_type, successful", [ - (4, 4, 32, AttentionType.DEFAULT_ATTENTION, True), + # TODO: Flash Atttention + # (4, 4, 32, AttentionType.DEFAULT_ATTENTION, True), (8, 2, 32, AttentionType.DEFAULT_ATTENTION, True), - (9, 8, 32, AttentionType.DEFAULT_ATTENTION, False), - (8, 3, 32, AttentionType.DEFAULT_ATTENTION, False), + # (9, 8, 32, AttentionType.DEFAULT_ATTENTION, False), + # (8, 3, 32, AttentionType.DEFAULT_ATTENTION, False), ], ) def test_grouped_query_attention_forward(n_head_q, n_head_kv, n_embd, attention_type, successful): From cb4f93216d723f0427c27663ad41abf4d1177396 Mon Sep 17 00:00:00 2001 From: Felix Stollenwerk Date: Mon, 11 Mar 2024 15:13:42 +0100 Subject: [PATCH 9/9] fix: group query attention implementation --- src/modalities/models/gpt2/gpt2_model.py | 41 +++++++++++++++--------- tests/models/test_attention.py | 13 +++++--- 2 files changed, 35 insertions(+), 19 deletions(-) diff --git a/src/modalities/models/gpt2/gpt2_model.py b/src/modalities/models/gpt2/gpt2_model.py index 0deba8ba..45d8f0d4 100644 --- a/src/modalities/models/gpt2/gpt2_model.py +++ b/src/modalities/models/gpt2/gpt2_model.py @@ -100,14 +100,22 @@ def __init__( ' For more details, read about "Grouped Query Attention"' ) - _joint_projection_factor = ( - 3 # the projection matrices for query, key & values are concatenated to a single matrix - ) + self.n_rep = n_head_q // n_head_kv - # key, query, value projections for all heads, but in a batch - self.c_attn = nn.Linear( + # query, 
key, value projections (separate) + self.q_attn = nn.Linear( + in_features=n_embd, + out_features=n_embd, + bias=bias, + ) + self.k_attn = nn.Linear( in_features=n_embd, - out_features=_joint_projection_factor * n_embd, + out_features=n_embd // self.n_rep, + bias=bias, + ) + self.v_attn = nn.Linear( + in_features=n_embd, + out_features=n_embd // self.n_rep, bias=bias, ) @@ -136,17 +144,20 @@ def __init__( ) def forward(self, x: torch.Tensor) -> torch.Tensor: - B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) + B, T, _ = x.size() # batch size (B), sequence length (T), embedding dimensionality (self.n_embd) # calculate query, key, values for all heads in batch and move head forward to be the batch dim - q, k, v = self.c_attn(x).split(self.n_embd, dim=2) - q = q.view(B, T, self.n_head_q, C // self.n_head_q).transpose(1, 2) # (B, nh_q, T, hs) - k = k.view(B, T, self.n_head_kv, C // self.n_head_kv).transpose(1, 2) # (B, nh_kv, T, hs) - v = v.view(B, T, self.n_head_kv, C // self.n_head_kv).transpose(1, 2) # (B, nh_kv, T, hs) + q = self.q_attn(x) # (B, T, n_embd) + k = self.k_attn(x) # (B, T, n_embd / n_rep) + v = self.v_attn(x) # (B, T, n_embd / n_rep) + + q = q.view(B, T, self.n_head_q, self.n_embd // self.n_head_q).transpose(1, 2) # (B, nh_q, T, hs) + k = k.view(B, T, self.n_head_kv, self.n_embd // self.n_head_q).transpose(1, 2) # (B, nh_kv, T, hs) + v = v.view(B, T, self.n_head_kv, self.n_embd // self.n_head_q).transpose(1, 2) # (B, nh_kv, T, hs) - # repeat k/v heads if n_kv_heads < n_heads - k = repeat_kv(k, self.n_head_q // self.n_head_kv) # (B, nh_q, T, hs) - v = repeat_kv(v, self.n_head_q // self.n_head_kv) # (B, nh_q, T, hs) + # repeat k/v heads if self.n_rep > 1 + k = repeat_kv(k, self.n_rep) # (B, nh_q, T, hs) + v = repeat_kv(v, self.n_rep) # (B, nh_q, T, hs) # causal self-attention; Self-attend: (B, nh_q, T, hs) x (B, nh_q, hs, T) -> (B, nh_q, T, T) if self.flash: @@ -166,7 +177,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: att = F.softmax(att, dim=-1) att = self.attn_dropout(att) y = att @ v # (B, nh_q, T, T) x (B, nh_q, T, hs) -> (B, nh_q, T, hs) - y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side + y = y.transpose(1, 2).contiguous().view(B, T, self.n_embd) # re-assemble all head outputs side by side # output projection y = self.resid_dropout(self.c_proj(y)) diff --git a/tests/models/test_attention.py b/tests/models/test_attention.py index c5657f8a..deb5f474 100644 --- a/tests/models/test_attention.py +++ b/tests/models/test_attention.py @@ -7,11 +7,16 @@ @pytest.mark.parametrize( "n_head_q, n_head_kv, n_embd, attention_type, successful", [ - # TODO: Flash Atttention - # (4, 4, 32, AttentionType.DEFAULT_ATTENTION, True), + # Flash Attention + (4, 4, 32, AttentionType.PYTORCH_FLASH_ATTENTION, True), + (8, 2, 32, AttentionType.PYTORCH_FLASH_ATTENTION, True), + (9, 8, 32, AttentionType.PYTORCH_FLASH_ATTENTION, False), + (8, 3, 32, AttentionType.PYTORCH_FLASH_ATTENTION, False), + # Default Attention + (4, 4, 32, AttentionType.DEFAULT_ATTENTION, True), (8, 2, 32, AttentionType.DEFAULT_ATTENTION, True), - # (9, 8, 32, AttentionType.DEFAULT_ATTENTION, False), - # (8, 3, 32, AttentionType.DEFAULT_ATTENTION, False), + (9, 8, 32, AttentionType.DEFAULT_ATTENTION, False), + (8, 3, 32, AttentionType.DEFAULT_ATTENTION, False), ], ) def test_grouped_query_attention_forward(n_head_q, n_head_kv, n_embd, attention_type, successful):