RAM UTILISATION IS INCREASING RAPIDLY #145

Open
UTSAV-44 opened this issue Oct 1, 2024 · 2 comments

Comments


UTSAV-44 commented Oct 1, 2024

To force the model to respond in JSON, I am using the ExLlamaV2TokenEnforcerFilter and ExLlamaV2PrefixFilter classes, appending them to a filters list, and passing that list as filters when generating output from the model. Since my use cases are limited, I thought of caching both filter objects in a dict and reusing them. After doing this, I observed that system RAM utilization keeps increasing, and after a few iterations it leads to Out of Memory. Usually the process takes 10-15 GB of system RAM, but over time the memory usage goes over 128 GB, causing OOM. I tried to identify the class responsible and found that ExLlamaV2TokenEnforcerFilter is not releasing some memory it holds, which creates this problem.

We tried reinitializing certain variables as shown below, but no memory was reclaimed.

    self.universal_filter_map[use_case_id][0].token_sequence = []
    self.universal_filter_map[use_case_id][1].current_prefixes = set()
    self.universal_filter_map[use_case_id][1].current_str = ""
    self.universal_filter_map[use_case_id][1].prefix_strings = ["{", " {"]

I have also logged this issue on ExLlamaV2: turboderp-org/exllamav2#639

I am sharing the complete implementation below.

def run_mihup_llm_inference(self, call_transcript: str, prompt_tuples: List[Tuple]) -> List[dict]:
    self.cache.reset()
    common_transcript = format_transcript_text(call_transcript)
    prompts = []
    filters = []
    use_case_ids = []
    for upper_tuple in prompt_tuples:
        use_case_id = upper_tuple[1]
        use_case_ids.append(use_case_id)
        p = upper_tuple[0]
        prompt_str = p[0]
        prompt_question_combined = format_llama3_prompt(mihup_system_prompt, common_transcript + prompt_str)
        prompts.append(prompt_question_combined)
        filter_schema_parser = p[1]

        print_memory_usage()

        if use_case_id not in self.universal_filter_map:
            print("Not found in the cache memory")

            self.universal_filter_map[use_case_id] = [
                ExLlamaV2TokenEnforcerFilter(filter_schema_parser, self.tokenizer),
                ExLlamaV2PrefixFilter(self.model, self.tokenizer, ["{", " {"])
            ]
        else:
            self.universal_filter_map[use_case_id][0].token_sequence = []
            self.universal_filter_map[use_case_id][1].current_prefixes = set()
            self.universal_filter_map[use_case_id][1].current_str = ""
            self.universal_filter_map[use_case_id][1].prefix_strings = ["{", " {"]
            print("Found in the cache memory")

        print("length of map : ", len(self.universal_filter_map[use_case_id]))
        # Reuse the cached filter pair for this use case
        filters.append(self.universal_filter_map[use_case_id])

    # print(prompts)

    outputs = self.generator.generate(
        prompt=prompts,
        filters=filters,
        filter_prefer_eos=True,
        max_new_tokens=1536,
        add_bos=True,
        stop_conditions=get_llama3_stop_conditions(self.tokenizer),
        completion_only=True,
        encode_special_tokens=True,
    )

    final_output = []
    for use_case_id, output in zip(use_case_ids, outputs):
        try:
            output_json = json.loads(output)
        except ValueError:
            # Skip outputs that are not valid JSON
            print("error: ", output)
            continue
        # Tag each parsed result with the use case that produced it
        output_json["use_case_id"] = use_case_id
        final_output.append(output_json)

    # gc.collect()
    print_memory_usage()

    return final_output

noamgat (Owner) commented Oct 1, 2024

Hi,
LMFE by default caches all encountered prefixes. The prefix cache cannot be emptied if there are in-flight requests.
However, you can clear it from time to time. If you want to clear it without modifying any library code, you can do something like this:

filter = ExLlamaV2TokenEnforcerFilter(filter_schema_parser, self.tokenizer)
for i in range(10000):
    # use filter here
    if i % 100 == 0:
        filter.token_enforcer.prefix_states = {}  # this is the important line

Let me know if this helps you solve the problem
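
Applied to a dict of cached filters like the one in the original post, the same reset could be wrapped in a small helper. This is only a sketch: `clear_prefix_states` is a hypothetical name, the stand-in objects replace the real filters so the snippet runs without a model, and it assumes that only the token-enforcer filter exposes `token_enforcer.prefix_states`.

```python
from types import SimpleNamespace


def clear_prefix_states(filter_map):
    """Reset prefix_states on every cached filter that exposes a token_enforcer.

    Returns the number of filters whose cache was reset.
    """
    cleared = 0
    for cached_filters in filter_map.values():
        for cached_filter in cached_filters:
            # Assumption: only the token-enforcer filter carries this attribute;
            # any other filter (e.g. the prefix filter) is skipped.
            enforcer = getattr(cached_filter, "token_enforcer", None)
            if enforcer is not None and hasattr(enforcer, "prefix_states"):
                enforcer.prefix_states = {}
                cleared += 1
    return cleared


# Stand-in objects so the sketch runs without a model; real code would pass
# self.universal_filter_map instead.
fake_enforcer_filter = SimpleNamespace(
    token_enforcer=SimpleNamespace(prefix_states={(1, 2): "state"})
)
fake_prefix_filter = SimpleNamespace(current_str="")
demo_map = {"use_case_a": [fake_enforcer_filter, fake_prefix_filter]}

cleared_count = clear_prefix_states(demo_map)
```

Since the prefix cache cannot be emptied while requests are in flight, a helper like this would be called between `generate()` calls, not during one.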

UTSAV-44 (Author) commented Oct 4, 2024

Hello,
I tried the suggested solution, but RAM usage is still increasing, although at a slower rate. Something is still being cached.
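
One way to narrow down what is still growing might be to compare two `tracemalloc` snapshots taken around a batch of inference calls. The sketch below uses only the standard library; `top_growth` is a hypothetical helper and the `retained` list is a stand-in for the real `generate()` calls. Note that `tracemalloc` only tracks memory allocated through Python, so growth inside native extension code may not appear in the report.

```python
import tracemalloc


def top_growth(before, after, limit=5):
    """Return the source lines whose allocated size changed the most."""
    return after.compare_to(before, "lineno")[:limit]


tracemalloc.start()
before = tracemalloc.take_snapshot()

# Stand-in for a batch of generate() calls; this list simulates retained state.
retained = [bytes(1024) for _ in range(2000)]

after = tracemalloc.take_snapshot()
growth = top_growth(before, after)
for stat in growth:
    print(stat)
tracemalloc.stop()
```

Each printed entry names a file and line together with the size difference between the two snapshots, which can point at the structure that keeps accumulating entries.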
