Rice CS Lab Makes Leaps in Memory Compression for Machine Learning

Novel approach achieves 10,000X reduction in memory use

Rice University Associate Professor Anshumali Shrivastava’s lab published two papers this past summer at the International Conference on Learning Representations (ICLR 2024) and the International World Wide Web Conference (WWW 2024).

Both papers detail advances in computing efficiency with the potential to drastically improve the field: one in machine learning (ML) models broadly, the other in graph neural networks (GNNs).

Running ChatGPT on Your Phone

The first paper, In Defense of Parameter Sharing for Model Compression, outlines a way to reduce the memory footprint of ML models, specifically large language models (LLMs) like ChatGPT.

Parameters are, in a nutshell, the variables inside a machine learning model that it adjusts as it learns from new information.
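
(In a simple linear model y = w·x + b, for example, w and b are the parameters: the values the model tunes during training to fit its data.)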

LLMs are notorious for their high memory use. Substantial storage and computing power are needed to hold the vast number of parameters an LLM must evaluate to calculate responses on the fly. Reducing their memory usage would allow these models to run faster and more efficiently, priming them for better mainstream integration.

Methods currently used to mitigate that large memory footprint include removing parameters outright (pruning, sometimes called parameter shedding) and representing each parameter with fewer bits (quantization). While somewhat effective, these methods have their limitations.
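
For readers who want a concrete picture of those two baselines, here is a minimal sketch in plain NumPy (our illustration, not any particular library's implementation): pruning zeroes out the smallest weights, and quantization stores each weight in 8 bits instead of 32.

```python
import numpy as np

weights = np.random.default_rng(1).normal(size=1000).astype(np.float32)

# Pruning ("parameter shedding"): zero out the smallest-magnitude half.
mask = np.abs(weights) > np.quantile(np.abs(weights), 0.5)
pruned = weights * mask

# Quantization: store each weight as an 8-bit integer plus one scale factor.
scale = np.abs(weights).max() / 127.0
quantized = np.round(weights / scale).astype(np.int8)

print(weights.nbytes, quantized.nbytes)  # 4000 bytes -> 1000 bytes: 4x saving
```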

Shrivastava’s lab instead uses a process called parameter sharing. Parameter sharing, according to Shrivastava, means assigning a single stored value to multiple parameters, tying them together. He explained:

“Say we need 100 parameters but…I only have space for five parameters. Then essentially what I can do is say…the first 20 can only have one value. The second 20 can only have one value [and so on]. So I essentially still have 100 parameters, but they can only have five values. That’s basically the crux of what we have been showing for the last year…And in this paper, we were able to show that there is something fundamental about it, why this should always work and why this is completely different from [the usual methods].”
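
To make that concrete, here is a minimal sketch of the idea (our illustration, not the lab's actual code), using the block assignment from the quote above; hash-based assignments behave similarly.

```python
import numpy as np

rng = np.random.default_rng(0)

n_virtual = 100   # parameters the model "sees"
n_shared = 5      # values actually stored in memory

shared_values = rng.normal(size=n_shared)  # the only stored (trainable) array

# Map each virtual parameter to one shared slot: the first 20 share slot 0,
# the next 20 share slot 1, and so on, mirroring Shrivastava's example.
assignment = np.arange(n_virtual) // (n_virtual // n_shared)

virtual_params = shared_values[assignment]  # materialized on the fly

print(virtual_params.shape)       # (100,): the model still has 100 parameters
print(np.unique(virtual_params))  # ...but they take only 5 distinct values
```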

The larger the model, the more effective this process becomes, dramatically reducing the amount of memory necessary to run an LLM (or any neural network) until it can, as Shrivastava stated, conceivably be run locally on a mobile device.

“If you take a small neural network you only probably get 1x, 2x; smaller compression,” explained Shrivastava, but “If you take large neural networks you get 100, 1000. Large neural networks with…100 gigs, that’s a lot of parameters and we got like 10,000X compression.”
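
For scale, a little arithmetic on those figures: at 10,000X compression, a 100-gigabyte parameter set would shrink to roughly 10 megabytes of stored values, comfortably within a phone's memory.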

Reducing the computing power needed to run AI has other benefits. It keeps data closer to the end user, improving data security. It also uses less energy, emitting less carbon into the atmosphere.

Improving Graph Neural Networks (GNNs)

Learning Scalable Structural Representations for Link Prediction with Bloom Signatures, the second publication from Shrivastava’s lab, addresses a similar efficiency issue in graph neural networks.

We see GNNs used every day, and some of their more common applications include: 

  • Social media networks
  • Drug discovery
  • Natural language processing in LLMs

Any collection of connected variables, Shrivastava explained, can be expressed as a graph. Link prediction is the ability of a neural network to predict missing or future connections between those variables.

Shrivastava used the scenario of predicting whether you will develop diabetes as a real-world application of link prediction. To gauge that risk, you’d need to examine your DNA and the DNA of genetic relatives like parents, siblings, and cousins. This interconnected group could be expressed, he explained, as a graph.

Say all these people who share DNA live in the same neighborhood. If multiple people in that neighborhood developed diabetes, there may be some connection among them causing it, and it may be more likely you will develop it as well. Determining that likelihood based on those connections is, in essence, link prediction.

“You’re obviously connected to your parents and siblings but maybe also to your cousins and others, and now it's a complete, complex graph. And in this graph, if a lot of your neighborhood has diabetes, most likely you could,” Shrivastava explained.

Solving a problem that complex requires capturing a richer representation of the data, which can be computationally expensive, and it would be far too time-consuming to calculate by hand. That’s why people have turned to GNNs as a tool to help.

GNNs on their own are, however, surprisingly weak at link prediction: a standard GNN computes one representation per node, which misses the pairwise structural information the task depends on. Rice researchers hope to change that by using Bloom signatures to significantly improve the accuracy of GNN link prediction.

Bloom signatures are hash-based sketches akin to fingerprints: compact snapshots of a node’s neighborhood data that shrink the memory needed to run link prediction problems.

In the neighborhood analogy, it’s like taking a snapshot of the medical history of each household and using it to calculate a link instead of going person by person. Shrivastava’s lab saw significant, provable improvements in link prediction accuracy with this approach.
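
To give a flavor of how such signatures work, here is a toy sketch (our construction, not the paper's exact design): each node's neighbor set is hashed into a short bit array, and comparing bit arrays stands in for comparing raw neighbor lists.

```python
import hashlib

M = 64  # signature length in bits
K = 3   # hash functions per set element

def _bit(item, seed):
    # Derive one stable bit position from (seed, item).
    digest = hashlib.blake2b(f"{seed}:{item}".encode(), digest_size=8).digest()
    return int.from_bytes(digest, "big") % M

def bloom_signature(neighbors):
    """Compress a node's neighbor set into an M-bit Bloom-filter-style sketch."""
    bits = 0
    for item in neighbors:
        for seed in range(K):
            bits |= 1 << _bit(item, seed)
    return bits

# Two nodes' neighborhoods in a toy family graph.
sig_a = bloom_signature({"mom", "dad", "sibling", "cousin"})
sig_b = bloom_signature({"sibling", "cousin", "friend"})

# Bits set in both sketches hint at shared neighbors, a classic
# link-prediction signal recovered from the compact signatures alone.
print(bin(sig_a & sig_b).count("1"), "overlapping set bits")
```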

“We were able to solve [the problem] and improve the accuracy of the system by using Bloom signatures,” he said. Going forward, this research could theoretically contribute to faster, better neural networks used across a range of AI applications.