<?xml version="1.0" encoding="UTF-8" standalone="no"?><?xml-stylesheet href="http://www.blogger.com/styles/atom.css" type="text/css"?><feed xmlns="http://www.w3.org/2005/Atom" xmlns:blogger="http://schemas.google.com/blogger/2008" xmlns:gd="http://schemas.google.com/g/2005" xmlns:georss="http://www.georss.org/georss" xmlns:openSearch="http://a9.com/-/spec/opensearchrss/1.0/" xmlns:thr="http://purl.org/syndication/thread/1.0"><id>tag:blogger.com,1999:blog-8474926331452026626</id><updated>2025-10-07T22:55:30.582-07:00</updated><category term="Machine Learning"/><category term="Deep Learning"/><category term="Computer Vision"/><category term="Natural Language Processing"/><category term="Google Brain"/><category term="open source"/><category term="Research"/><category term="Publications"/><category term="conference"/><category term="Machine Perception"/><category term="Natural Language Understanding"/><category term="TensorFlow"/><category term="conferences"/><category term="datasets"/><category term="Education"/><category term="Neural Networks"/><category term="Health"/><category term="Reinforcement Learning"/><category term="University Relations"/><category term="Robotics"/><category term="AI"/><category term="Algorithms"/><category term="CVPR"/><category term="NLP"/><category term="Quantum Computing"/><category term="Multimodal Learning"/><category term="Speech"/><category term="Machine Intelligence"/><category term="Research Awards"/><category term="Computational Photography"/><category term="On-device Learning"/><category term="AI for Social Good"/><category term="Security and Privacy"/><category term="HCI"/><category term="Computer Science"/><category term="Quantum AI"/><category term="MOOC"/><category term="ICLR"/><category term="Machine Translation"/><category term="optimization"/><category term="Image Classification"/><category term="Self-Supervised Learning"/><category term="accessibility"/><category term="Pixel"/><category term="Visualization"/><category 
term="YouTube"/><category term="NeurIPS"/><category term="AutoML"/><category term="Hardware"/><category term="Responsible AI"/><category term="ACL"/><category term="Audio"/><category term="ICML"/><category term="ML"/><category term="Physics"/><category term="TPU"/><category term="Android"/><category term="EMNLP"/><category term="ML Fairness"/><category term="video"/><category term="Awards"/><category term="Search"/><category term="Structured Data"/><category term="Image Processing"/><category term="Information Retrieval"/><category term="Supervised Learning"/><category term="User Experience"/><category term="Collaboration"/><category term="Google Maps"/><category term="Graph Mining"/><category term="TTS"/><category term="distributed systems"/><category term="Automatic Speech Recognition"/><category term="Environment"/><category term="Google Accelerated Science"/><category term="Speech Recognition"/><category term="DeepMind"/><category term="Google Translate"/><category term="Large Language Models"/><category term="Video Analysis"/><category term="statistics"/><category term="2022 Year-in-Review"/><category term="ACM"/><category term="Chemistry"/><category term="Diversity"/><category term="Earth Engine"/><category term="K-12"/><category term="RAI-HCT Highlights"/><category term="UI"/><category term="Vision Research"/><category term="Acoustic Modeling"/><category term="Interspeech"/><category term="Systems"/><category term="Voice Search"/><category term="data science"/><category term="ph.d. 
fellowship"/><category term="Augmented Reality"/><category term="Cloud Computing"/><category term="Compression"/><category term="Differential Privacy"/><category term="Google Cloud Platform"/><category term="ICCV"/><category term="Machine Hearing"/><category term="NIPS"/><category term="Semi-supervised Learning"/><category term="Software"/><category term="Translate"/><category term="Unsupervised Learning"/><category term="crowd-sourcing"/><category term="grants"/><category term="market algorithms"/><category term="Faculty Summit"/><category term="Google Genomics"/><category term="Recommender Systems"/><category term="Semantic Models"/><category term="Art"/><category term="Biology"/><category term="Climate"/><category term="Course Builder"/><category term="Data Discovery"/><category term="Google Photos"/><category term="Google+"/><category term="PhD Fellowship"/><category term="Social Networks"/><category term="WWW"/><category term="ads"/><category term="renewable energy"/><category term="Computational Imaging"/><category term="Europe"/><category term="Expander"/><category term="Fusion Tables"/><category term="Google Books"/><category term="Graphs"/><category term="Kaggle"/><category term="Moore's Law"/><category term="Ngram"/><category term="Optical Character Recognition"/><category term="Virtual Reality"/><category term="Year in Review"/><category term="schema.org"/><category term="API"/><category term="Africa"/><category term="App Engine"/><category term="Gboard"/><category term="Generative AI"/><category term="Gmail"/><category term="Google Play Apps"/><category term="High Dynamic Range Imaging"/><category term="Image Annotation"/><category term="India"/><category term="Internet of Things"/><category term="NAACL"/><category term="Networks"/><category term="Style Transfer"/><category term="Weather"/><category term="economics"/><category term="internationalization"/><category term="publication"/><category term="resource optimization"/><category term="search 
ads"/><category term="wikipedia"/><category term="Adaptive Data Analysis"/><category term="Android Wear"/><category term="App Inventor"/><category term="China"/><category term="DeepDream"/><category term="EMEA"/><category term="Exacycle"/><category term="Genomics"/><category term="Google Docs"/><category term="Google Drive"/><category term="Google Science Fair"/><category term="Google Sheets"/><category term="Graph"/><category term="Inbox"/><category term="KDD"/><category term="Keyboard Input"/><category term="Labs"/><category term="Low-Light Photography"/><category term="MapReduce"/><category term="Policy"/><category term="Proposals"/><category term="TensorBoard"/><category term="VLDB"/><category term="electronics"/><category term="osdi"/><category term="patents"/><category term="trends"/><category term="April Fools"/><category term="Australia"/><category term="BigQuery"/><category term="CHI"/><category term="Cantonese"/><category term="Chrome"/><category term="Conservation"/><category term="Data Center"/><category term="ECCV"/><category term="Electronic Commerce and Algorithms"/><category term="Encryption"/><category term="Entity Salience"/><category term="Faculty Institute"/><category term="Flu Trends"/><category term="Google Cloud"/><category term="Google I/O"/><category term="Google Trips"/><category term="Google Voice Search"/><category term="Government"/><category term="High-Performance Computing"/><category term="ICSE"/><category term="IPython"/><category term="Journalism"/><category term="Klingon"/><category term="Korean"/><category term="Linear Optimization"/><category term="Magenta"/><category term="Market Research"/><category term="Mixed Reality"/><category term="Network Management"/><category term="Nexus"/><category term="Peer Review"/><category term="PhotoScan"/><category term="PiLab"/><category term="Professional Development"/><category term="Public Data Explorer"/><category term="SIGCOMM"/><category term="SIGMOD"/><category term="Site Reliability 
Engineering"/><category term="Sound Search"/><category term="TV"/><category term="UNIX"/><category term="Visiting Faculty"/><category term="Wiki"/><category term="adsense"/><category term="adwords"/><category term="correlate"/><category term="entities"/><category term="gamification"/><category term="jsm"/><category term="jsm2011"/><category term="localization"/><category term="materials science"/><category term="operating systems"/><category term="osdi10"/><title type="text">Google AI Blog</title><subtitle type="html">The latest news from Google AI.</subtitle><link href="http://blog.research.google/feeds/posts/default" rel="http://schemas.google.com/g/2005#feed" type="application/atom+xml"/><link href="http://www.blogger.com/feeds/8474926331452026626/posts/default?alt=atom&amp;redirect=false" rel="self" type="application/atom+xml"/><link href="http://blog.research.google/" rel="alternate" type="text/html"/><link href="http://pubsubhubbub.appspot.com/" rel="hub"/><link href="http://www.blogger.com/feeds/8474926331452026626/posts/default?alt=atom&amp;start-index=26&amp;max-results=25&amp;redirect=false" rel="next" type="application/atom+xml"/><author><name>ewood</name><uri>http://www.blogger.com/profile/12341551220176883769</uri><email>noreply@blogger.com</email><gd:image height="16" rel="http://schemas.google.com/g/2005#thumbnail" src="https://img1.blogblog.com/img/b16-rounded.gif" width="16"/></author><generator uri="http://www.blogger.com" version="7.00">Blogger</generator><openSearch:totalResults>1352</openSearch:totalResults><openSearch:startIndex>1</openSearch:startIndex><openSearch:itemsPerPage>25</openSearch:itemsPerPage><entry><id>tag:blogger.com,1999:blog-8474926331452026626.post-1569605132526995799</id><published>2024-03-29T11:03:00.000-07:00</published><updated>2024-03-29T11:03:10.261-07:00</updated><category scheme="http://www.blogger.com/atom/ns#" term="Climate"/><category scheme="http://www.blogger.com/atom/ns#" term="Machine Learning"/><category 
scheme="http://www.blogger.com/atom/ns#" term="Weather"/><title type="text">Generative AI to quantify uncertainty in weather forecasting</title><content type="html">&lt;span class="byline-author"&gt;Posted by Lizao (Larry) Li, Software Engineer, and Rob Carver, Research Scientist, Google Research&lt;/span&gt;

&lt;img src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEglI5U51vvhkA4cAuVvMLn0TbbL5pdlFL-LO1sNnqLyUieA6A88I5HrhJlszxR1GKQqSK5wsdlATDKSy6EC1BsNF7tzS6oVlFLtau13mVFLk954nFu85HDMP3PrQboG4eXExEtUjEuDRFpcrMqE_F0ikSwXiWBECAfJiLbjr6h6523DROJkbC284xX35zC7/s1000/image3.gif" style="display: none;" /&gt;

&lt;p&gt;
Accurate weather forecasts can have a direct impact on people’s lives, from helping make routine decisions, like what to pack for a day’s activities, to informing urgent actions, for example, protecting people in the face of hazardous weather conditions. The importance of accurate and timely weather forecasts will only increase as the climate changes. Recognizing this, we at Google have been investing in weather and climate research to help ensure that the forecasting technology of tomorrow can meet the demand for reliable weather information. Some of our recent innovations include &lt;a href="https://blog.research.google/2023/11/metnet-3-state-of-art-neural-weather.html"&gt;MetNet-3&lt;/a&gt;, Google's high-resolution forecasts up to 24 hours into the future, and &lt;a href="https://deepmind.google/discover/blog/graphcast-ai-model-for-faster-and-more-accurate-global-weather-forecasting/"&gt;GraphCast&lt;/a&gt;, a weather model that can predict weather up to 10 days ahead.
&lt;/p&gt;
&lt;a name='more'&gt;&lt;/a&gt; 

&lt;p&gt;
Weather is inherently stochastic. To quantify the uncertainty, traditional methods rely on physics-based simulation to generate an ensemble of forecasts. However, it is computationally costly to generate a large ensemble so that rare and extreme weather events can be discerned and characterized accurately.  
&lt;/p&gt;
&lt;p&gt;
With that in mind, we are excited to announce our latest innovation designed to accelerate progress in weather forecasting, &lt;a href="https://www.science.org/doi/10.1126/sciadv.adk4489"&gt;Scalable Ensemble Envelope Diffusion Sampler&lt;/a&gt; (SEEDS), recently published in &lt;em&gt;&lt;a href="https://www.science.org/journal/sciadv"&gt;Science Advances&lt;/a&gt;&lt;/em&gt;. SEEDS is a generative AI model that can efficiently generate ensembles of weather forecasts &lt;em&gt;at scale&lt;/em&gt; at a small fraction of the cost of traditional physics-based forecasting models. This technology opens up novel opportunities for weather and climate science, and it represents one of the first applications of probabilistic diffusion models, a generative AI technology behind recent advances in media generation, to weather and climate forecasting.
&lt;/p&gt;
&lt;br /&gt; 

&lt;h2&gt;The need for probabilistic forecasts: the butterfly effect&lt;/h2&gt;

&lt;p&gt;
In December 1972, at the &lt;a href="https://www.aaas.org/"&gt;American Association for the Advancement of Science&lt;/a&gt; meeting in Washington, D.C., MIT meteorology professor &lt;a href="https://en.wikipedia.org/wiki/Edward_Norton_Lorenz"&gt;Ed Lorenz&lt;/a&gt; gave a talk entitled, “Does the Flap of a Butterfly's Wings in Brazil Set Off a Tornado in Texas?” which contributed to the term “&lt;a href="https://en.wikipedia.org/wiki/Butterfly_effect"&gt;butterfly effect&lt;/a&gt;”. He was building on his earlier, landmark 1963 paper where he examined the feasibility of “very-long-range weather prediction” and described how errors in initial conditions grow exponentially when integrated in time with numerical weather prediction models. This exponential error growth, known as chaos, results in a deterministic predictability limit that restricts the use of individual forecasts in decision making, because they do not quantify the inherent uncertainty of weather conditions. This is particularly problematic when forecasting extreme weather events, such as hurricanes, heatwaves, or floods.
&lt;/p&gt;
&lt;p&gt;
Recognizing the limitations of deterministic forecasts, weather agencies around the world issue &lt;em&gt;probabilistic forecasts&lt;/em&gt;. Such forecasts are based on ensembles of deterministic forecasts, each of which is generated by including synthetic noise in the initial conditions and stochasticity in the physical processes. Agencies leverage the fast error growth rate in weather models to make the forecasts in an ensemble purposefully different: the initial uncertainties are tuned to generate runs that are as different as possible, and the stochastic processes in the weather model introduce additional differences during the model run. The error growth is mitigated by averaging all the forecasts in the ensemble, and the variability in the ensemble of forecasts quantifies the uncertainty of the weather conditions.
&lt;/p&gt;
&lt;p&gt;
While effective, generating these probabilistic forecasts is computationally costly. They require running highly complex numerical weather models on massive supercomputers multiple times. Consequently, many operational weather forecasting systems can only afford to generate ~10–50 ensemble members for each forecast cycle. This is a problem for users concerned with the likelihood of rare but high-impact weather events, which typically require much larger ensembles to assess beyond a few days. For instance, one would need a 10,000-member ensemble to forecast the likelihood of events with 1% probability of occurrence with a relative error less than 10%. Quantifying the probability of such extreme events could be useful, for example, for emergency management preparation or for energy traders.
&lt;/p&gt;
&lt;br /&gt; 

&lt;h2&gt;SEEDS: AI-enabled advances&lt;/h2&gt;

&lt;p&gt;
In the aforementioned &lt;a href="https://www.science.org/doi/10.1126/sciadv.adk4489"&gt;paper&lt;/a&gt;, we present the Scalable Ensemble Envelope Diffusion Sampler (SEEDS), a generative AI technology for weather forecast ensemble generation. SEEDS is based on &lt;a href="https://blog.research.google/2021/07/high-fidelity-image-generation-using.html"&gt;denoising diffusion probabilistic&lt;/a&gt; models, a state-of-the-art generative AI method pioneered in part by Google Research.
&lt;/p&gt;
&lt;p&gt;
SEEDS can generate a large ensemble conditioned on as few as one or two forecasts from an operational numerical weather prediction system. The generated ensembles not only yield plausible real-weather–like forecasts but also match or exceed physics-based ensembles in skill metrics such as the &lt;a href="https://www.jstor.org/stable/26201352"&gt;rank histogram&lt;/a&gt;, the &lt;a href="https://en.wikipedia.org/wiki/Root-mean-square_deviation"&gt;root-mean-squared error&lt;/a&gt; (RMSE), and the &lt;a href="https://www.tandfonline.com/doi/abs/10.1198/016214506000001437"&gt;continuous ranked probability score&lt;/a&gt; (CRPS). In particular, the generated ensembles assign more accurate likelihoods to the tail of the forecast distribution, such as ±2σ and ±3σ weather events. Most importantly, the computational cost of the model is negligible when compared to the hours of computational time needed by supercomputers to make a forecast. It has a throughput of 256 ensemble members (at 2° resolution) per 3 minutes on Google Cloud TPUv3-32 instances and can easily scale to higher throughput by deploying more accelerators. 
&lt;/p&gt;

&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEglI5U51vvhkA4cAuVvMLn0TbbL5pdlFL-LO1sNnqLyUieA6A88I5HrhJlszxR1GKQqSK5wsdlATDKSy6EC1BsNF7tzS6oVlFLtau13mVFLk954nFu85HDMP3PrQboG4eXExEtUjEuDRFpcrMqE_F0ikSwXiWBECAfJiLbjr6h6523DROJkbC284xX35zC7/s1000/image3.gif" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="470" data-original-width="1000" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEglI5U51vvhkA4cAuVvMLn0TbbL5pdlFL-LO1sNnqLyUieA6A88I5HrhJlszxR1GKQqSK5wsdlATDKSy6EC1BsNF7tzS6oVlFLtau13mVFLk954nFu85HDMP3PrQboG4eXExEtUjEuDRFpcrMqE_F0ikSwXiWBECAfJiLbjr6h6523DROJkbC284xX35zC7/s16000/image3.gif" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;SEEDS generates an order-of-magnitude more samples to in-fill distributions of weather patterns.&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;
&lt;div style="line-height: 40%;"&gt;
    &lt;br /&gt;
&lt;/div&gt;

&lt;h2&gt;Generating plausible weather forecasts&lt;/h2&gt;


&lt;p&gt;
Generative AI is known to generate very detailed images and videos. This property is especially useful for generating ensemble forecasts that are consistent with plausible weather patterns, which ultimately result in the most added value for downstream applications. As Lorenz points out, “The [weather forecast] maps which they produce should look like real weather maps.” The figure below contrasts the forecasts from SEEDS with those from the operational U.S. weather prediction system (&lt;a href="https://www.emc.ncep.noaa.gov/emc/pages/numerical_forecast_systems/gefs.php"&gt;Global Ensemble Forecast System&lt;/a&gt;, GEFS) for a particular date during the &lt;a href="https://en.wikipedia.org/wiki/2022_European_heatwaves"&gt;2022 European heat waves&lt;/a&gt;. We also compare the results to the forecasts from a Gaussian model that predicts the univariate mean and standard deviation of each atmospheric field at each location, a common and computationally efficient but less sophisticated data-driven approach. This Gaussian model is meant to characterize the output of pointwise post-processing, which ignores correlations and treats each grid point as an independent random variable. In contrast, a real weather map would have detailed &lt;em&gt;correlational&lt;/em&gt; structures.
&lt;/p&gt;
&lt;p&gt;
Because SEEDS directly models the joint distribution of the atmospheric state, it realistically captures both the spatial covariance and the correlation between mid-tropospheric geopotential and mean sea level pressure, both of which are closely related and are commonly used by weather forecasters for evaluation and verification of forecasts. Gradients in the mean sea level pressure are what drive winds at the surface, while gradients in mid-tropospheric geopotential create upper-level winds that move large-scale weather patterns. 
&lt;/p&gt;
&lt;p&gt;
The generated samples from SEEDS shown in the figure below (frames Ca–Ch) display a geopotential trough west of Portugal with spatial structure similar to that found in the operational U.S. forecasts or the reanalysis based on observations. Although the Gaussian model predicts the marginal univariate distributions adequately, it fails to capture cross-field or spatial correlations. This hinders the assessment of the effects that these anomalies may have on hot air intrusions from North Africa, which can exacerbate heat waves over Europe.
&lt;/p&gt;

&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgQE94TGK404COMAKKxaPwUO9bD8gIzQfu6A0u5c-5xbGKhlUtBW_0KAj-Ur8kpgt5_f-IjAuFzeecpRbbWVujZNQVExTsl0UuDRtOb84Y8uFWc4G1UYYZos6gLVtIHQ3AZ7ojRqoMSmt8IHdTOSx365AaoNyUfNMi1ksC0Wh_axeD_THB6sOmnZZHhrvHQ/s1999/image2.png" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="1999" data-original-width="1675" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgQE94TGK404COMAKKxaPwUO9bD8gIzQfu6A0u5c-5xbGKhlUtBW_0KAj-Ur8kpgt5_f-IjAuFzeecpRbbWVujZNQVExTsl0UuDRtOb84Y8uFWc4G1UYYZos6gLVtIHQ3AZ7ojRqoMSmt8IHdTOSx365AaoNyUfNMi1ksC0Wh_axeD_THB6sOmnZZHhrvHQ/s16000/image2.png" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;Stamp maps over Europe on 2022/07/14 at 0:00 UTC. The contours are for the mean sea level pressure (dashed lines mark isobars below 1010 hPa) while the heatmap depicts the geopotential height at the 500 hPa pressure level. (A) The&amp;nbsp;&lt;a href="https://www.ecmwf.int/en/forecasts/dataset/ecmwf-reanalysis-v5"&gt;ERA5&lt;/a&gt;&amp;nbsp;reanalysis, a proxy for real observations. (Ba-Bb) 2 members from the 7-day U.S. operational forecasts used as seeds to our model. (Ca-Ch) 8 samples drawn from SEEDS. (Da-Dh) 8 non-seeding members from the 7-day U.S. operational ensemble forecast. (Ea-Ed) 4 samples from a pointwise Gaussian model parameterized by the mean and variance of the entire U.S. operational ensemble.&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;

&lt;div style="line-height: 40%;"&gt;
    &lt;br /&gt;
&lt;/div&gt;
&lt;h2&gt;Covering extreme events more accurately  &lt;/h2&gt;

&lt;p&gt;
Below we show the joint distributions of temperature at 2 meters and total column water vapor near Lisbon during the extreme heat event on 2022/07/14, at 1:00 local time. We used the 7-day forecasts issued on 2022/07/07. For each plot, we generate 16,384-member ensembles with SEEDS. The observed weather event from ERA5 is denoted by the star. The operational ensemble is also shown, with squares denoting the forecasts used to seed the generated ensembles, and triangles denoting the rest of the ensemble members.
&lt;/p&gt;

&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgVbbmrrrJ5L1NVb_O7WPUD-d6ULlTJTSns6ZaqjxOqZ4YAi4zOiT72rfMBf8EGTe0kdofIrWAMESq1m2v9IBjnd_k6UAIDM7LvhbxdVr41FOQ0fqkKeERF_QqXbxs94qKLdMxR-A7Hbxkjd4zZn07AlldAsuvn7jsYCu-V3UVAatovY1ELbrcLQz5I1ppX/s1999/image1.png" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="941" data-original-width="1999" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgVbbmrrrJ5L1NVb_O7WPUD-d6ULlTJTSns6ZaqjxOqZ4YAi4zOiT72rfMBf8EGTe0kdofIrWAMESq1m2v9IBjnd_k6UAIDM7LvhbxdVr41FOQ0fqkKeERF_QqXbxs94qKLdMxR-A7Hbxkjd4zZn07AlldAsuvn7jsYCu-V3UVAatovY1ELbrcLQz5I1ppX/s16000/image1.png" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;SEEDS provides better statistical coverage of the 2022/07/14 European extreme heat event, denoted by the brown star. Each plot shows the values of the total column-integrated water vapor (TCWV) vs. temperature over a grid point near Lisbon, Portugal, from 16,384 samples generated by our models, shown as green dots, conditioned on 2 seeds (blue squares) taken from the 7-day U.S. operational ensemble forecasts (denoted by the sparser brown triangles). The valid forecast time is 1:00 local time. The solid contour levels correspond to iso-proportions of the kernel density of SEEDS, with the outermost one encircling 95% of the mass and 11.875% between each level.&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;
&lt;br /&gt;

&lt;p&gt;
According to the U.S. operational ensemble, the observed event was so unlikely seven days prior that none of its 31 members predicted near-surface temperatures as warm as those observed. Indeed, the event probability computed from a Gaussian kernel density estimate is lower than 1%, which means that ensembles with fewer than 100 members are unlikely to contain forecasts as extreme as this event. In contrast, the SEEDS ensembles are able to extrapolate from the two seeding forecasts, providing an envelope of possible weather states with much better statistical coverage of the event. This allows both quantifying the probability of the event taking place and sampling weather regimes under which it would occur. Specifically, our highly scalable generative approach enables the creation of very large ensembles that can characterize very rare events by providing samples of weather states exceeding a given threshold for any user-defined diagnostic.
&lt;/p&gt;
&lt;br /&gt; 

&lt;h2&gt;Conclusion and future outlook&lt;/h2&gt;

&lt;p&gt;
SEEDS leverages the power of generative AI to produce ensemble forecasts comparable to those from the operational U.S. forecast system, but at an accelerated pace. The results reported in this paper need only 2 seeding forecasts from the operational system, which generates 31 forecasts in its current version. This leads to a hybrid forecasting system where a few weather trajectories computed with a physics-based model are used to seed a diffusion model that can generate additional forecasts much more efficiently. This methodology provides an alternative to the current operational weather forecasting paradigm, where the computational resources saved by the statistical emulator could be allocated to increasing the resolution of the physics-based model or issuing forecasts more frequently.
&lt;/p&gt;
&lt;p&gt;
We believe that SEEDS represents just one of the many ways that AI will accelerate progress in operational numerical weather prediction in coming years. We hope this demonstration of the utility of generative AI for weather forecast emulation and post-processing will spur its application in research areas such as climate risk assessment, where generating a large number of ensembles of climate projections is crucial to accurately quantifying the uncertainty about future climate.
&lt;/p&gt;
&lt;br /&gt; 

&lt;h2&gt;Acknowledgements&lt;/h2&gt;

&lt;p&gt;
&lt;em&gt;All SEEDS authors, Lizao Li, Rob Carver, Ignacio Lopez-Gomez, Fei Sha and John Anderson, co-authored this blog post, with Carla Bromberg as Program Lead. We also thank Tom Small who designed the animation. Our colleagues at Google Research have provided invaluable advice to the SEEDS work. Among them, we thank Leonardo Zepeda-Núñez, Zhong Yi Wan, Stephan Rasp, Stephan Hoyer, and Tapio Schneider for their inputs and useful discussion. We thank Tyler Russell for additional technical program management, as well as Alex Merose for data coordination and support. We also thank Cenk Gazen, Shreya Agrawal, and Jason Hickey for discussions in the early stage of the SEEDS work. &lt;/em&gt;
&lt;/p&gt;</content><link href="http://blog.research.google/feeds/1569605132526995799/comments/default" rel="replies" title="Post Comments" type="application/atom+xml"/><link href="http://blog.research.google/2024/03/generative-ai-to-quantify-uncertainty.html#comment-form" rel="replies" title="0 Comments" type="text/html"/><link href="http://www.blogger.com/feeds/8474926331452026626/posts/default/1569605132526995799" rel="edit" type="application/atom+xml"/><link href="http://www.blogger.com/feeds/8474926331452026626/posts/default/1569605132526995799" rel="self" type="application/atom+xml"/><link href="http://blog.research.google/2024/03/generative-ai-to-quantify-uncertainty.html" rel="alternate" title="Generative AI to quantify uncertainty in weather forecasting" type="text/html"/><author><name>Google AI</name><uri>http://www.blogger.com/profile/12098626514775266161</uri><email>noreply@blogger.com</email><gd:image height="16" rel="http://schemas.google.com/g/2005#thumbnail" src="https://img1.blogblog.com/img/b16-rounded.gif" width="16"/></author><media:thumbnail xmlns:media="http://search.yahoo.com/mrss/" height="72" url="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEglI5U51vvhkA4cAuVvMLn0TbbL5pdlFL-LO1sNnqLyUieA6A88I5HrhJlszxR1GKQqSK5wsdlATDKSy6EC1BsNF7tzS6oVlFLtau13mVFLk954nFu85HDMP3PrQboG4eXExEtUjEuDRFpcrMqE_F0ikSwXiWBECAfJiLbjr6h6523DROJkbC284xX35zC7/s72-c/image3.gif" width="72"/><thr:total>0</thr:total></entry><entry><id>tag:blogger.com,1999:blog-8474926331452026626.post-1799535679952845079</id><published>2024-03-28T13:53:00.000-07:00</published><updated>2024-03-29T12:00:03.604-07:00</updated><category scheme="http://www.blogger.com/atom/ns#" term="Machine Learning"/><category scheme="http://www.blogger.com/atom/ns#" term="Neural Networks"/><category scheme="http://www.blogger.com/atom/ns#" term="open source"/><category scheme="http://www.blogger.com/atom/ns#" term="statistics"/><title type="text">AutoBNN: Probabilistic time series forecasting 
with compositional Bayesian neural networks</title><content type="html">&lt;span class="byline-author"&gt;Posted by Urs Köster, Software Engineer, Google Research&lt;/span&gt;

&lt;img src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgd5Wc54p1HvgIokpazxDsMo1u6i9wg3ovpNOiFc4-wYwebETvjs9-hm2wxZ4osNbBAxhet8To3hwGg-whFScksHQB_BP1kS4Z8Cu7FQT2bjVtJl4trPid-OxCyYocwyRTN66tuvAedu9z0FepBg4zZvmLbLxY6uuib8p5jVH2kfb3RxT_HMABsKMXuSFXr/s320/AutoBNN.jpg" style="display: none;" /&gt;

&lt;p&gt;
&lt;a href="https://en.wikipedia.org/wiki/Time_series"&gt;Time series&lt;/a&gt; problems are ubiquitous, from forecasting weather and traffic patterns to understanding economic trends. &lt;a href="https://en.wikipedia.org/wiki/Bayesian_inference"&gt;Bayesian&lt;/a&gt; approaches start with an assumption about the data's patterns (prior probability), collect evidence (e.g., new time series data), and continuously update that assumption to form a posterior probability distribution. Traditional Bayesian approaches like &lt;a href="https://gaussianprocess.org/gpml/"&gt;Gaussian processes&lt;/a&gt; (GPs) and &lt;a href="https://blog.tensorflow.org/2019/03/structural-time-series-modeling-in.html"&gt;Structural Time Series&lt;/a&gt; are extensively used for modeling time series data, e.g., the commonly used &lt;a href="https://gml.noaa.gov/ccgg/trends/"&gt;Mauna Loa CO2&lt;/a&gt; dataset. However, they often rely on domain experts to painstakingly select appropriate model components and may be computationally expensive. Alternatives such as neural networks lack interpretability, making it difficult to understand how they generate forecasts, and don't produce reliable confidence intervals.
&lt;/p&gt;
&lt;a name='more'&gt;&lt;/a&gt;
&lt;p&gt;
To that end, we introduce &lt;a href="https://github.com/tensorflow/probability/tree/main/spinoffs/autobnn"&gt;AutoBNN&lt;/a&gt;, a new open-source package written in &lt;a href="https://github.com/google/jax"&gt;JAX&lt;/a&gt;. AutoBNN automates the discovery of interpretable time series forecasting models, provides high-quality uncertainty estimates, and scales effectively for use on large datasets. We describe how AutoBNN combines the interpretability of traditional probabilistic approaches with the scalability and flexibility of neural networks.
&lt;/p&gt;

&lt;div style="line-height: 40%;"&gt;
    &lt;br /&gt;
&lt;/div&gt;
&lt;h2&gt;AutoBNN&lt;/h2&gt;


&lt;p&gt;
AutoBNN is based on a &lt;a href="https://proceedings.mlr.press/v28/duvenaud13.html"&gt;line&lt;/a&gt; &lt;a href="https://royalsocietypublishing.org/doi/10.1098/rsta.2011.0550"&gt;of&lt;/a&gt; &lt;a href="https://proceedings.mlr.press/v202/saad23a.html"&gt;research&lt;/a&gt; that over the past decade has yielded improved predictive accuracy by modeling time series using GPs with learned &lt;a href="https://www.cs.toronto.edu/~duvenaud/cookbook/"&gt;kernel&lt;/a&gt; structures. The kernel function of a GP encodes assumptions about the function being modeled, such as the presence of trends, periodicity or noise. With learned GP kernels, the kernel function is defined compositionally: it is either a base kernel (such as &lt;code&gt;Linear&lt;/code&gt;, &lt;code&gt;Quadratic&lt;/code&gt;, &lt;code&gt;Periodic&lt;/code&gt;, &lt;code&gt;&lt;a href="https://en.wikipedia.org/wiki/Mat%C3%A9rn_covariance_function"&gt;Matérn&lt;/a&gt;&lt;/code&gt; or &lt;code&gt;ExponentiatedQuadratic&lt;/code&gt;) or a composite that combines two or more kernel functions using operators such as &lt;code&gt;Addition&lt;/code&gt;, &lt;code&gt;Multiplication&lt;/code&gt;, or &lt;code&gt;&lt;a href="https://icml.cc/Conferences/2010/papers/170.pdf"&gt;ChangePoint&lt;/a&gt;&lt;/code&gt;. This compositional kernel structure serves two related purposes. First, it is simple enough that a user who is an expert on their data, but not necessarily on GPs, can construct a reasonable prior for their time series. Second, techniques like &lt;a href="https://www.stats.ox.ac.uk/~doucet/doucet_defreitas_gordon_smcbookintro.pdf"&gt;Sequential Monte Carlo&lt;/a&gt; can be used for discrete searches over small structures and can output interpretable results.&lt;/p&gt;

&lt;p&gt;
AutoBNN improves upon these ideas, replacing the GP with &lt;a href="https://www.cs.toronto.edu/~duvenaud/distill_bayes_net/public/"&gt;Bayesian neural networks&lt;/a&gt; (BNNs) while retaining the compositional kernel structure. A BNN is a neural network with a probability distribution over weights rather than a fixed set of weights. This induces a distribution over outputs, capturing uncertainty in the predictions. BNNs bring the following advantages over GPs: First, training large GPs is computationally expensive, and traditional training algorithms scale as the cube of the number of data points in the time series. In contrast, for a fixed width, training a BNN will often be approximately linear in the number of data points. Second, BNNs lend themselves better to GPU and &lt;a href="https://cloud.google.com/tpu?hl=en"&gt;TPU&lt;/a&gt; hardware acceleration than GP training operations. Third, compositional BNNs can be easily combined with &lt;a href="https://arxiv.org/abs/2007.06823"&gt;traditional deep BNNs&lt;/a&gt;, which have the ability to do feature discovery. One could imagine "hybrid" architectures, in which users specify a top-level structure of &lt;code&gt;Add&lt;/code&gt;(&lt;code&gt;Linear&lt;/code&gt;, &lt;code&gt;Periodic&lt;/code&gt;, &lt;code&gt;Deep&lt;/code&gt;), and the deep BNN is left to learn the contributions from potentially high-dimensional covariate information.
&lt;/p&gt;

&lt;p&gt;
How, then, might one translate a GP with compositional kernels into a BNN? A single-layer neural network will typically converge to a GP as the number of neurons (or "width") &lt;a href="https://link.springer.com/chapter/10.1007/978-1-4612-0745-0_2"&gt;goes to infinity&lt;/a&gt;. More recently, researchers have &lt;a href="https://openreview.net/forum?id=gRwh5HkdaTm"&gt;discovered&lt;/a&gt; a correspondence in the other direction — many popular GP &lt;a href="https://www.cs.toronto.edu/~duvenaud/cookbook/"&gt;kernels&lt;/a&gt; (such as &lt;code&gt;Matern&lt;/code&gt;, &lt;code&gt;ExponentiatedQuadratic&lt;/code&gt;, &lt;code&gt;Polynomial&lt;/code&gt; or &lt;code&gt;Periodic&lt;/code&gt;) can be obtained as infinite-width BNNs with appropriately chosen activation functions and weight distributions. Furthermore, these BNNs remain close to the corresponding GP even at small, finite widths. For example, the figures below compare the &lt;a href="https://en.wikipedia.org/wiki/Covariance_matrix#:~:text=In%20probability%20theory%20and%20statistics,of%20a%20given%20random%20vector"&gt;covariance&lt;/a&gt; between pairs of observations and the &lt;a href="https://en.wikipedia.org/wiki/Kriging"&gt;regression&lt;/a&gt; results of the true GPs with those of their corresponding width-10 neural network versions.
&lt;/p&gt;

&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEiHJ7hHI33S76Id3RrWCYezQKky9oELeuWf_CTm7GYadxpV7-B9GSQKCZgTmVQABi9zpWcEK8uvTYITyX2_jcbv_qF-eGv2C1QkU9oDCAS09FfoCne81yEAqC5moTNIqsn05aHfWNr8uy48N3UfV_tRGOyGrrQvB8l7RegzAq5_LNK2W8_Y_gSavdfi5aDI/s1350/image3.png" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="598" data-original-width="1350" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEiHJ7hHI33S76Id3RrWCYezQKky9oELeuWf_CTm7GYadxpV7-B9GSQKCZgTmVQABi9zpWcEK8uvTYITyX2_jcbv_qF-eGv2C1QkU9oDCAS09FfoCne81yEAqC5moTNIqsn05aHfWNr8uy48N3UfV_tRGOyGrrQvB8l7RegzAq5_LNK2W8_Y_gSavdfi5aDI/s16000/image3.png" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;Comparison of &lt;a href="https://en.wikipedia.org/wiki/Gram_matrix"&gt;Gram matrices&lt;/a&gt; between true GP kernels (top row) and their width 10 neural network approximations (bottom row).&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;




&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhoidYqlAK2J1n4y71Qn-WuIcmaxGI9ynwSjtHAvyukuY_q5QcX4pVEheX2pwMxIhkAu7_OZR-0s7N7e-cU-caromj1wntP7E1txZfxHqh2yeTedusA90k9hFZ2yvzEZmC2QlPyR7trgVuMro-MoicBxpAbrkQXs2F9h1uux3AXzUENmJ0NA8Ch9dyICT15/s1328/image4.png" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="586" data-original-width="1328" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhoidYqlAK2J1n4y71Qn-WuIcmaxGI9ynwSjtHAvyukuY_q5QcX4pVEheX2pwMxIhkAu7_OZR-0s7N7e-cU-caromj1wntP7E1txZfxHqh2yeTedusA90k9hFZ2yvzEZmC2QlPyR7trgVuMro-MoicBxpAbrkQXs2F9h1uux3AXzUENmJ0NA8Ch9dyICT15/s16000/image4.png" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;Comparison of regression results between true GP kernels (top row) and their width 10 neural network approximations (bottom row).&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;



&lt;p&gt;
Finally, the translation is completed with &lt;a href="https://arxiv.org/abs/1905.06076"&gt;BNN analogues&lt;/a&gt; of the &lt;code&gt;Addition&lt;/code&gt; and &lt;code&gt;Multiplication&lt;/code&gt; operators over GPs, and input warping to produce periodic kernels. BNN addition is straightforwardly given by adding the outputs of the component BNNs. BNN multiplication is achieved by multiplying the activations of the hidden layers of the BNNs and then applying a shared dense layer. Consequently, only BNNs with the same hidden width can be multiplied.
&lt;/p&gt;

&lt;div style="line-height: 40%;"&gt;
    &lt;br /&gt;
&lt;/div&gt;
&lt;h2&gt;Using AutoBNN&lt;/h2&gt;


&lt;p&gt;
The AutoBNN &lt;a href="https://github.com/tensorflow/probability/tree/main/spinoffs/autobnn"&gt;package&lt;/a&gt; is available within &lt;a href="https://www.tensorflow.org/probability"&gt;TensorFlow Probability&lt;/a&gt;. It is implemented in &lt;a href="https://github.com/google/jax"&gt;JAX&lt;/a&gt; and uses the &lt;a href="https://github.com/google/flax"&gt;flax.linen&lt;/a&gt; neural network library. It implements all of the base kernels and operators discussed so far (&lt;code&gt;Linear&lt;/code&gt;, &lt;code&gt;Quadratic&lt;/code&gt;, &lt;code&gt;Matern&lt;/code&gt;, &lt;code&gt;ExponentiatedQuadratic&lt;/code&gt;, &lt;code&gt;Periodic&lt;/code&gt;, &lt;code&gt;Addition&lt;/code&gt;, &lt;code&gt;Multiplication&lt;/code&gt;) plus one new kernel and three new operators:
&lt;/p&gt;

&lt;ul&gt;

&lt;li&gt;a &lt;code&gt;OneLayer&lt;/code&gt; kernel, a single hidden layer &lt;a href="https://en.wikipedia.org/wiki/Rectifier_(neural_networks)"&gt;ReLU&lt;/a&gt; BNN,

&lt;/li&gt;&lt;li&gt;a &lt;code&gt;&lt;a href="https://icml.cc/Conferences/2010/papers/170.pdf"&gt;ChangePoint&lt;/a&gt;&lt;/code&gt; operator that allows smoothly switching between two kernels,

&lt;/li&gt;&lt;li&gt;a &lt;code&gt;LearnableChangePoint&lt;/code&gt; operator which is the same as &lt;code&gt;ChangePoint&lt;/code&gt; except position and slope are given prior distributions and can be learnt from the data, and

&lt;/li&gt;&lt;li&gt;a &lt;code&gt;WeightedSum&lt;/code&gt; operator.
&lt;/li&gt;
&lt;/ul&gt;


&lt;p&gt;
&lt;code&gt;WeightedSum&lt;/code&gt; combines two or more BNNs with learnable mixing weights, where the learnable weights follow a &lt;a href="https://en.wikipedia.org/wiki/Dirichlet_distribution"&gt;Dirichlet prior&lt;/a&gt;. By default, a flat Dirichlet distribution with concentration 1.0 is used.
&lt;/p&gt;

&lt;p&gt;
&lt;code&gt;WeightedSums&lt;/code&gt; allow a "soft" version of structure discovery, i.e., training a linear combination of many possible models at once. In contrast to structure discovery with discrete structures, such as in &lt;a href="https://proceedings.mlr.press/v202/saad23a.html"&gt;AutoGP&lt;/a&gt;, this allows us to use standard gradient methods to learn structures, rather than using expensive discrete optimization. Instead of evaluating potential combinatorial structures in series, WeightedSum allows us to evaluate them in parallel. 
&lt;/p&gt;

&lt;p&gt;
To easily enable exploration, AutoBNN defines a &lt;a href="https://github.com/tensorflow/probability/blob/main/spinoffs/autobnn/autobnn/models.py"&gt;number of model structures&lt;/a&gt; that contain either top-level or internal &lt;code&gt;WeightedSums&lt;/code&gt;. The names of these models can be used as the first parameter in any of the &lt;a href="https://github.com/tensorflow/probability/blob/main/spinoffs/autobnn/autobnn/estimators.py"&gt;estimator&lt;/a&gt; constructors, and include things like &lt;code&gt;&lt;a href="https://github.com/tensorflow/probability/blob/main/spinoffs/autobnn/autobnn/models.py#L133"&gt;sum_of_stumps&lt;/a&gt;&lt;/code&gt; (the &lt;code&gt;WeightedSum&lt;/code&gt; over all the base kernels) and &lt;code&gt;sum_of_shallow&lt;/code&gt; (which adds all possible combinations of base kernels with all operators).&lt;/p&gt;

&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgNmWFuh7tVRkaF9o4nr3Fu7B2CNmXpDkGx8_9fMASh2olAfjlSdBXLj-0cgh7UIVWs6fHlNyyCvRPA_vc4eq-3lixkC2VXzCeSCZBFDHIc1qYfK53EwEdngf1KykzCfpPiIg3YoN46AZkBSSmCLrgPXX84PaZp_cxLrNnmojz2S6pLOCmTTT2niRi8Qfe5/s1389/image2.png" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="255" data-original-width="1389" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgNmWFuh7tVRkaF9o4nr3Fu7B2CNmXpDkGx8_9fMASh2olAfjlSdBXLj-0cgh7UIVWs6fHlNyyCvRPA_vc4eq-3lixkC2VXzCeSCZBFDHIc1qYfK53EwEdngf1KykzCfpPiIg3YoN46AZkBSSmCLrgPXX84PaZp_cxLrNnmojz2S6pLOCmTTT2niRi8Qfe5/s16000/image2.png" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;Illustration of the &lt;code&gt;sum_of_stumps&lt;/code&gt; model. The bars in the top row show the amount by which each base kernel contributes, and the bottom row shows the function represented by the base kernel. The resulting weighted sum is shown on the right.&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;

&lt;p&gt;
The figure below demonstrates the technique of structure discovery on the N374 series (a time series of yearly financial data starting in 1949) from the &lt;a href="https://forecasters.org/resources/time-series-data/m3-competition/"&gt;M3&lt;/a&gt; dataset. The six base structures were &lt;code&gt;ExponentiatedQuadratic&lt;/code&gt; (which is the same as the Radial Basis Function kernel, or &lt;a href="https://en.wikipedia.org/wiki/Radial_basis_function_kernel"&gt;RBF&lt;/a&gt; for short), &lt;code&gt;Matern&lt;/code&gt;, &lt;code&gt;Linear&lt;/code&gt;, &lt;code&gt;Quadratic&lt;/code&gt;, &lt;code&gt;OneLayer&lt;/code&gt; and &lt;code&gt;Periodic&lt;/code&gt; kernels. The figure shows the MAP estimates of their weights over an ensemble of 32 particles. All of the high-likelihood particles gave a large weight to the &lt;code&gt;Periodic&lt;/code&gt; component, low weights to &lt;code&gt;Linear&lt;/code&gt;, &lt;code&gt;Quadratic&lt;/code&gt; and &lt;code&gt;OneLayer&lt;/code&gt;, and a large weight to either &lt;code&gt;RBF&lt;/code&gt; or &lt;code&gt;Matern&lt;/code&gt;.
&lt;/p&gt;



&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEi_5mU3VknB1oyCwNdCQj9kWTVV5J0BuylHB8W2LUK4sT6JpkOWdluZwh8_fKvRN5eSo2xBbQ0pRxDYa86IqML9H2-JZOmxxRJSm9ExG_PUr6U7iFl8nyp4lEaNpG3guYov3hPP3l9zifdu_iv_5aeP05OftccGqwJ7D0WAeMox_aWMGm3hN5nOkrj4BPxU/s868/image5.png" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="542" data-original-width="868" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEi_5mU3VknB1oyCwNdCQj9kWTVV5J0BuylHB8W2LUK4sT6JpkOWdluZwh8_fKvRN5eSo2xBbQ0pRxDYa86IqML9H2-JZOmxxRJSm9ExG_PUr6U7iFl8nyp4lEaNpG3guYov3hPP3l9zifdu_iv_5aeP05OftccGqwJ7D0WAeMox_aWMGm3hN5nOkrj4BPxU/s16000/image5.png" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;Parallel coordinates plot of the &lt;a href="https://www.probabilitycourse.com/chapter9/9_1_2_MAP_estimation.php"&gt;MAP&lt;/a&gt; estimates of the base kernel weights over 32 particles. The &lt;code&gt;sum_of_stumps&lt;/code&gt; model was trained on the N374 series from the M3 dataset (insert in blue). Darker lines correspond to particles with higher likelihoods.&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;


&lt;p&gt;
By using &lt;code&gt;WeightedSums&lt;/code&gt; as the inputs to other operators, it is possible to express rich combinatorial structures, while keeping models compact and the number of learnable weights small. As an example, we include the &lt;code&gt;sum_of_products&lt;/code&gt; model (illustrated in the figure below), which first forms two products, each of a pair of &lt;code&gt;WeightedSums&lt;/code&gt;, and then sums the two products. By setting some of the weights to zero, we can create many different discrete structures. The total number of possible structures in this model is 2&lt;sup&gt;16&lt;/sup&gt;, since there are 16 base kernels that can be turned on or off. All these structures are explored implicitly by training just this one model.
&lt;/p&gt;


&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEh9VhSV6af55mkKxUKzpJJrqQiAV6WUWJ8HY9Q-5qcPB_mr8_P0lvrcGGkEUNe_-UB6Ri5VgWFkdHvRwEe7snZucQtvzMR_548jt4h2lbTzfnp7ZUeYFDmas7LwKc_9UAzdLE4gr8g9pVVkMXy9GU8qMUzrKfd9tjDEc2C4Ub6aXDzjHf2FjCryg_pWu39E/s1754/AutoBNN%20illustration.png" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="640" data-original-width="1754" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEh9VhSV6af55mkKxUKzpJJrqQiAV6WUWJ8HY9Q-5qcPB_mr8_P0lvrcGGkEUNe_-UB6Ri5VgWFkdHvRwEe7snZucQtvzMR_548jt4h2lbTzfnp7ZUeYFDmas7LwKc_9UAzdLE4gr8g9pVVkMXy9GU8qMUzrKfd9tjDEc2C4Ub6aXDzjHf2FjCryg_pWu39E/s16000/AutoBNN%20illustration.png" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;Illustration of the "sum_of_products" model. Each of the four WeightedSums have the same structure as the "sum_of_stumps" model.&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;


&lt;p&gt;
We have found, however, that certain combinations of kernels (e.g., the product of &lt;code&gt;Periodic&lt;/code&gt; and either the &lt;code&gt;Matern&lt;/code&gt; or &lt;code&gt;ExponentiatedQuadratic&lt;/code&gt;) lead to overfitting on many datasets. To prevent this, we have defined model classes like &lt;code&gt;sum_of_safe_shallow&lt;/code&gt; that exclude such products when performing structure discovery with &lt;code&gt;WeightedSums&lt;/code&gt;.
&lt;/p&gt;

&lt;p&gt;
For training, AutoBNN provides &lt;code&gt;AutoBnnMapEstimator&lt;/code&gt; and &lt;code&gt;AutoBnnMCMCEstimator&lt;/code&gt; to perform MAP and MCMC inference, respectively. Either estimator can be combined with any of the six &lt;a href="https://github.com/tensorflow/probability/blob/main/spinoffs/autobnn/autobnn/likelihoods.py"&gt;likelihood functions&lt;/a&gt;, including four based on normal distributions with different noise characteristics for continuous data and two based on the negative binomial distribution for count data.  
&lt;/p&gt;



&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgVzVWT-e-lcT53h75r2QJpR7iH9FAgCkpQY_oBNq7o1YoO4TkJ2GVpXLYcyY3RjOfgaXRM2LRII_jK31PbxTQF29yH1cTJRdI-XkXmnZMR_imlFv0uOuIPni3nW_vb1ercfuJuKHbrbuIA4bVR5EuGTs5iUHRXs-4WaA9wFEX54RwOJQt0BGMGfkNW4kxn/s1076/image1.png" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="280" data-original-width="1076" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgVzVWT-e-lcT53h75r2QJpR7iH9FAgCkpQY_oBNq7o1YoO4TkJ2GVpXLYcyY3RjOfgaXRM2LRII_jK31PbxTQF29yH1cTJRdI-XkXmnZMR_imlFv0uOuIPni3nW_vb1ercfuJuKHbrbuIA4bVR5EuGTs5iUHRXs-4WaA9wFEX54RwOJQt0BGMGfkNW4kxn/s16000/image1.png" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;Result from running AutoBNN on the &lt;a href="https://gml.noaa.gov/ccgg/trends/"&gt;Mauna Loa CO2&lt;/a&gt; dataset in our example &lt;a href="https://github.com/tensorflow/probability/blob/main/discussion/examples/Forecasting_With_AutoBNN.ipynb"&gt;colab&lt;/a&gt;. The model captures the trend and seasonal component in the data. Extrapolating into the future, the mean prediction slightly underestimates the actual trend, while the 95% confidence interval gradually increases.&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;


&lt;p&gt;
To fit a model like in the figure above, all it takes is the following 10 lines of code, using the &lt;a href="https://scikit-learn.org/stable/"&gt;scikit-learn&lt;/a&gt;–inspired estimator interface:&lt;/p&gt;


&lt;pre class="prettyprint"&gt;import autobnn as ab

model = ab.operators.Add(
    bnns=(ab.kernels.PeriodicBNN(width=50),
          ab.kernels.LinearBNN(width=50),
          ab.kernels.MaternBNN(width=50)))

estimator = ab.estimators.AutoBnnMapEstimator(
    model, 'normal_likelihood_logistic_noise', jax.random.PRNGKey(42),
    periods=[12])

estimator.fit(my_training_data_xs, my_training_data_ys)
low, mid, high = estimator.predict_quantiles(my_training_data_xs)
&lt;/pre&gt;

&lt;br /&gt;

&lt;div style="line-height: 40%;"&gt;
    &lt;br /&gt;
&lt;/div&gt;
&lt;h2&gt;Conclusion&lt;/h2&gt;


&lt;p&gt;
&lt;a href="https://github.com/tensorflow/probability/tree/main/spinoffs/autobnn"&gt;AutoBNN&lt;/a&gt; provides a powerful and flexible framework for building sophisticated time series prediction models. By combining the strengths of BNNs and GPs with compositional kernels, AutoBNN opens a world of possibilities for understanding and forecasting complex data. We invite the community to try the&amp;nbsp;&lt;a href="https://github.com/tensorflow/probability/blob/main/discussion/examples/Forecasting_With_AutoBNN.ipynb" target="_blank"&gt;colab&lt;/a&gt;, and leverage this library to innovate and solve real-world challenges. 
&lt;/p&gt;

&lt;div style="line-height: 40%;"&gt;
    &lt;br /&gt;
&lt;/div&gt;
&lt;h2&gt;Acknowledgements&lt;/h2&gt;


&lt;p&gt;
&lt;em&gt;AutoBNN was written by Colin Carroll, Thomas Colthurst, Urs Köster and Srinivas Vasudevan. We would like to thank Kevin Murphy, Brian Patton and Feras Saad for their advice and feedback.&lt;/em&gt;
&lt;/p&gt;&lt;p&gt;&lt;/p&gt;</content><link href="http://blog.research.google/feeds/1799535679952845079/comments/default" rel="replies" title="Post Comments" type="application/atom+xml"/><link href="http://blog.research.google/2024/03/autobnn-probabilistic-time-series.html#comment-form" rel="replies" title="0 Comments" type="text/html"/><link href="http://www.blogger.com/feeds/8474926331452026626/posts/default/1799535679952845079" rel="edit" type="application/atom+xml"/><link href="http://www.blogger.com/feeds/8474926331452026626/posts/default/1799535679952845079" rel="self" type="application/atom+xml"/><link href="http://blog.research.google/2024/03/autobnn-probabilistic-time-series.html" rel="alternate" title="AutoBNN: Probabilistic time series forecasting with compositional bayesian neural networks" type="text/html"/><author><name>Google AI</name><uri>http://www.blogger.com/profile/12098626514775266161</uri><email>noreply@blogger.com</email><gd:image height="16" rel="http://schemas.google.com/g/2005#thumbnail" src="https://img1.blogblog.com/img/b16-rounded.gif" width="16"/></author><media:thumbnail xmlns:media="http://search.yahoo.com/mrss/" height="72" url="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgd5Wc54p1HvgIokpazxDsMo1u6i9wg3ovpNOiFc4-wYwebETvjs9-hm2wxZ4osNbBAxhet8To3hwGg-whFScksHQB_BP1kS4Z8Cu7FQT2bjVtJl4trPid-OxCyYocwyRTN66tuvAedu9z0FepBg4zZvmLbLxY6uuib8p5jVH2kfb3RxT_HMABsKMXuSFXr/s72-c/AutoBNN.jpg" width="72"/><thr:total>0</thr:total></entry><entry><id>tag:blogger.com,1999:blog-8474926331452026626.post-7061041222399769838</id><published>2024-03-20T13:54:00.000-07:00</published><updated>2024-03-20T13:54:06.249-07:00</updated><category scheme="http://www.blogger.com/atom/ns#" term="Health"/><category scheme="http://www.blogger.com/atom/ns#" term="Machine Learning"/><category scheme="http://www.blogger.com/atom/ns#" term="User Experience"/><title type="text">Computer-aided diagnosis for lung cancer screening</title><content 
type="html">&lt;span class="byline-author"&gt;Posted by Atilla Kiraly, Software Engineer, and Rory Pilgrim, Product Manager, Google Research &lt;/span&gt;


&lt;img src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjFpuCd82OUmuS2oG2cVir_ZgeOyUpFndr-kCq8V4pDv6fzxeyViBJymfVt5FFUqgkM_X57msxNv84XBtaXs2FsD7R8_tNqtH6D8X_KiMtZRaJ37JphQsvM35_gIk-4Tn2eEYvrInjMLV5ouwhRJv3Oqb30Z71P546NszeURINBoJnlWnzgASn-6D9YFwZo/s320/PULMA%20hero.jpg" style="display: none;" /&gt;

&lt;p&gt;
Lung cancer is the leading cause of cancer-related deaths globally, with &lt;a href="https://www.who.int/news-room/fact-sheets/detail/cancer#:~:text=The%20most%20common%20causes%20of,rectum%20(916%20000%20deaths)%3B"&gt;1.8 million deaths&lt;/a&gt; reported in 2020. Late diagnosis dramatically reduces the chances of survival. &lt;a href="https://www.cdc.gov/cancer/lung/basic_info/screening.htm"&gt;Lung cancer screening&lt;/a&gt; via &lt;a href="https://www.cancer.gov/about-cancer/diagnosis-staging/ct-scans-fact-sheet#:~:text=indicate%20real%20problems.-,Lung%20cancer,-Low%2Ddose%20CT"&gt;computed tomography&lt;/a&gt; (CT), which provides a detailed 3D image of the lungs, has been shown to reduce mortality in high-risk populations by at least 20% by detecting potential signs of cancer earlier. In the US, screening involves annual scans, with some countries or cases recommending more or less frequent scans.
&lt;/p&gt;
&lt;a name='more'&gt;&lt;/a&gt;
&lt;p&gt;
The &lt;a href="https://www.uspreventiveservicestaskforce.org/uspstf/recommendation/lung-cancer-screening"&gt;United States Preventive Services Task Force&lt;/a&gt; recently expanded lung cancer screening recommendations by &lt;a href="https://pubmed.ncbi.nlm.nih.gov/34636916/"&gt;roughly 80%&lt;/a&gt;, which is expected to increase screening access for women and racial and ethnic minority groups. However, false positives (i.e., incorrectly reporting a potential cancer in a cancer-free patient) can cause anxiety and lead to unnecessary procedures for patients while increasing costs for the healthcare system. Moreover, efficiency in screening a large number of individuals can be challenging depending on healthcare infrastructure and radiologist availability.
&lt;/p&gt;


&lt;p&gt;
At Google we have previously developed &lt;a href="https://blog.google/technology/health/lung-cancer-prediction/"&gt;machine learning (ML) models for lung cancer detection&lt;/a&gt;, and have evaluated their ability to automatically detect and classify regions that show signs of potential cancer. Performance has been shown to be comparable to that of specialists in detecting possible cancer. While they have achieved high performance, effectively communicating findings in realistic environments is necessary to realize their full potential.
&lt;/p&gt;

&lt;p&gt;
To that end, in “&lt;a href="https://pubs.rsna.org/doi/10.1148/ryai.230079"&gt;Assistive AI in Lung Cancer Screening: A Retrospective Multinational Study in the US and Japan&lt;/a&gt;”, published in &lt;em&gt;&lt;a href="https://pubs.rsna.org/journal/ai"&gt;Radiology: Artificial Intelligence&lt;/a&gt;&lt;/em&gt;, we investigate how ML models can effectively communicate findings to radiologists. We also introduce a generalizable user-centric interface to help radiologists leverage such models for lung cancer screening. The system takes CT imaging as input and outputs a cancer suspicion rating using four categories (no suspicion, probably benign, suspicious, highly suspicious) along with the corresponding regions of interest. We evaluate the system’s utility in improving clinician performance through randomized reader studies in both the US and Japan, using the local cancer scoring systems (&lt;a href="https://www.acr.org/-/media/ACR/Files/RADS/Lung-RADS/LungRADSAssessmentCategoriesv1-1.pdf"&gt;Lung-RADS v1.1&lt;/a&gt; and &lt;a href="https://www.jscts.org/pdf/guideline/gls3rdfig_english130621.pdf"&gt;Sendai Score&lt;/a&gt;) and image viewers that mimic realistic settings. We found that reader specificity increases with model assistance in both reader studies. To accelerate progress in conducting similar studies with ML models, we have &lt;a href="https://github.com/Google-Health/google-health/tree/master/ct_dicom"&gt;open-sourced code&lt;/a&gt; to process CT images and generate images compatible with the &lt;a href="https://en.wikipedia.org/wiki/Picture_archiving_and_communication_system"&gt;picture archiving and communication system&lt;/a&gt; (PACS) used by radiologists.
&lt;/p&gt;

&lt;div style="line-height: 40%;"&gt;
    &lt;br /&gt;
&lt;/div&gt;
&lt;h2&gt;Developing an interface to communicate model results&lt;/h2&gt;


&lt;p&gt;
Integrating ML models into radiologist workflows involves understanding the nuances and goals of their tasks to meaningfully support them. In the case of lung cancer screening, hospitals follow various country-specific guidelines that are regularly updated. For example, in the US, Lung-RADS v1.1 assigns an &lt;a href="https://www.acr.org/-/media/ACR/Files/RADS/Lung-RADS/LungRADSAssessmentCategoriesv1-1.pdf"&gt;alphanumeric score&lt;/a&gt; to indicate the lung cancer risk and follow-up recommendations. When assessing patients, radiologists load the CT in their workstation to read the case, find lung nodules or lesions, and apply the applicable guidelines to determine follow-up decisions.
&lt;/p&gt;


&lt;p&gt;
Our first step was to improve the &lt;a href="https://blog.google/technology/health/lung-cancer-prediction/"&gt;previously developed ML models&lt;/a&gt; through additional training data and architectural improvements, including &lt;a href="https://research.google/pubs/attention-is-all-you-need/"&gt;self-attention&lt;/a&gt;. Then, instead of targeting specific guidelines, we experimented with a complementary way of communicating AI results independent of guidelines or their particular versions. Specifically, the system output offers a suspicion rating and localization (regions of interest) for the user to consider in conjunction with their own specific guidelines. The interface produces output images directly associated with the CT study, requiring no changes to the user’s workstation. The radiologist only needs to review a small set of additional images. There is no other change to their system or interaction with the system.
&lt;/p&gt;


&lt;p&gt;


&lt;/p&gt;&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEiChGqKLOWQAzrIzk294q6i6XuUoR1ul0qoTAR8RHQw-bZT-ulyruug-HNY8f2em7ZgzHE1UP6yQbe4plM0gkmXu6KwcTmsNogbr6FjTGzSDrBEDFhVLQ4TdbxVp_bbB21gA_jR84-1r9ly-O5HXqOzuZERgJyjFSYtZty7h6J3UErWsP0-DoQ1pFZtyjiw/s857/image1.png" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="436" data-original-width="857" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEiChGqKLOWQAzrIzk294q6i6XuUoR1ul0qoTAR8RHQw-bZT-ulyruug-HNY8f2em7ZgzHE1UP6yQbe4plM0gkmXu6KwcTmsNogbr6FjTGzSDrBEDFhVLQ4TdbxVp_bbB21gA_jR84-1r9ly-O5HXqOzuZERgJyjFSYtZty7h6J3UErWsP0-DoQ1pFZtyjiw/s16000/image1.png" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;Example of the assistive lung cancer screening system outputs. Results for the radiologist’s evaluation are visualized on the location of the CT volume where the suspicious lesion is found. The overall suspicion is displayed at the top of the CT images. Circles highlight the suspicious lesions while squares show a rendering of the same lesion from a different perspective, called a sagittal view.&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;


&lt;p&gt;
The assistive lung cancer screening system comprises 13 models and has a high-level architecture similar to the end-to-end system used in &lt;a href="https://blog.google/technology/health/lung-cancer-prediction/"&gt;prior work&lt;/a&gt;. The models coordinate with each other to first segment the lungs, obtain an overall assessment, and locate three suspicious regions, and then use this information to assign a suspicion rating to each region. The system was deployed on Google Cloud using &lt;a href="https://cloud.google.com/kubernetes-engine"&gt;Google Kubernetes Engine&lt;/a&gt; (GKE), which pulled the images, ran the ML models, and provided results. This setup allows the system to scale and connects it directly to the servers where the images are stored in &lt;a href="https://cloud.google.com/healthcare-api/docs/concepts/dicom"&gt;DICOM stores&lt;/a&gt;.
&lt;/p&gt;



&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhlQLk7XcQtSX367ubw0D0TtTqZQg-H69p63qtVrGir3UfJcYUyys0n_Nks-YqURRklRWllhSKdH-FFjRvfkb9mGxEmL191sfpAclKD085x-u20FJS9BWJGULyLk0foVGKfq5T5F7_hx7Z4xHu1ZeHPLM63HUCaiCrkt8BThhiImts9epWqqCE2s0BLeoWU/s646/image4.png" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="394" data-original-width="646" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhlQLk7XcQtSX367ubw0D0TtTqZQg-H69p63qtVrGir3UfJcYUyys0n_Nks-YqURRklRWllhSKdH-FFjRvfkb9mGxEmL191sfpAclKD085x-u20FJS9BWJGULyLk0foVGKfq5T5F7_hx7Z4xHu1ZeHPLM63HUCaiCrkt8BThhiImts9epWqqCE2s0BLeoWU/s16000/image4.png" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;Outline of the Google Cloud deployment of the assistive lung cancer screening system and the directional calling flow for the individual components that serve the images and compute results. Images are served to the viewer and to the system using Google Cloud services. The system is run on a Google Kubernetes Engine that pulls the images, processes them, and writes them back into the DICOM store.&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;
  
&lt;br /&gt;

&lt;div style="line-height: 40%;"&gt;
    &lt;br /&gt;
&lt;/div&gt;
&lt;h2&gt;Reader studies &lt;/h2&gt;


&lt;p&gt;
To evaluate the system’s utility in improving clinical performance, we conducted two reader studies (i.e., experiments that assess clinical performance by comparing expert performance with and without the aid of a technology) with 12 radiologists using pre-existing, de-identified CT scans. We presented 627 challenging cases to 6 US-based and 6 Japan-based radiologists. In the experimental setup, readers were divided into two groups that read each case twice, with and without assistance from the model. Readers were asked to apply scoring guidelines they typically use in their clinical practice and report their overall suspicion of cancer for each case. We then compared the readers’ responses to measure the impact of the model on their workflow and decisions. The scores and suspicion levels were judged against the actual cancer outcomes of the individuals to measure sensitivity, specificity, and &lt;a href="https://developers.google.com/machine-learning/crash-course/classification/roc-and-auc#:~:text=AUC%20stands%20for%20%22Area%20under,across%20all%20possible%20classification%20thresholds."&gt;area under the ROC curve&lt;/a&gt; (AUC) values. These were compared with and without assistance.
&lt;/p&gt;


&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgmiP7GWIMf_TKezxSK0sM8EOtfm2M3QoZtgvYfcjacMm2atdilirD93ftlu_QlyusIu_ocC6R0iHX1eXtHrU6g1yLUWnZ1Bq0FJ0nXEjTezptuSxGbpwDFIkQGeZrFPmwXV3IYvyzJYPCEhp4etRNzhGmHbbfQAwntOm4ZhQNpuXbei5sfN6MqsQXJctVH/s794/image3.png" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="297" data-original-width="794" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgmiP7GWIMf_TKezxSK0sM8EOtfm2M3QoZtgvYfcjacMm2atdilirD93ftlu_QlyusIu_ocC6R0iHX1eXtHrU6g1yLUWnZ1Bq0FJ0nXEjTezptuSxGbpwDFIkQGeZrFPmwXV3IYvyzJYPCEhp4etRNzhGmHbbfQAwntOm4ZhQNpuXbei5sfN6MqsQXJctVH/s16000/image3.png" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;A multi-case multi-reader study involves each case being reviewed by each reader twice, once with ML system assistance and once without. In this visualization one reader first reviews Set A without assistance (&lt;strong&gt;blue&lt;/strong&gt;) and then with assistance (&lt;strong&gt;orange&lt;/strong&gt;) after a wash-out period. A second reader group follows the opposite path by reading the same set of cases Set A with assistance first. Readers are randomized to these groups to remove the effect of ordering.&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;


&lt;p&gt;
The ability to conduct these studies using the same interface highlights the interface’s generalizability to completely different cancer scoring systems, and the generalization of the model and assistive capability to different patient populations. Our study results demonstrated that when radiologists used the system in their clinical evaluation, they had an increased ability to correctly identify lung images without actionable lung cancer findings (i.e., &lt;em&gt;specificity&lt;/em&gt;) by an absolute 5–7% compared to when they didn’t use the assistive system. This potentially means that for every 15–20 patients screened, one may be able to avoid unnecessary follow-up procedures, thus reducing their anxiety and the burden on the health care system. This can, in turn, help improve the sustainability of lung cancer screening programs, particularly as &lt;a href="https://pubmed.ncbi.nlm.nih.gov/34636916/"&gt;more people become eligible for screening&lt;/a&gt;. 
&lt;/p&gt;


&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEiDMKrqRR9njVuYSLV0Nzb7-MXdpyJTSofvvxFhyendGwnM9pddFyy48MVBWKsadYMUp1RGQBNL77vC0gCvjZ_fIsIQ8ZhGHZmy52srebu49xIL4wYkuvyftssXzvohoSoBKt9C2uwua6gz4ReO4LQvfMbhdrgtXvcYb3JruZAchta2n5MhU41pTpJLyMJI/s1999/image2.png" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="824" data-original-width="1999" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEiDMKrqRR9njVuYSLV0Nzb7-MXdpyJTSofvvxFhyendGwnM9pddFyy48MVBWKsadYMUp1RGQBNL77vC0gCvjZ_fIsIQ8ZhGHZmy52srebu49xIL4wYkuvyftssXzvohoSoBKt9C2uwua6gz4ReO4LQvfMbhdrgtXvcYb3JruZAchta2n5MhU41pTpJLyMJI/s16000/image2.png" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;Reader specificity increases with ML model assistance in both the US-based and Japan-based reader studies. Specificity values were derived from reader scores from actionable findings (something suspicious was found) versus no actionable findings, compared against the true cancer outcome of the individual.  Under model assistance, readers flagged fewer cancer-negative individuals for follow-up visits. Sensitivity for cancer positive individuals remained the same.&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;


&lt;div style="line-height: 40%;"&gt;
    &lt;br /&gt;
&lt;/div&gt;
&lt;h2&gt;Translating this into real-world impact through partnership &lt;/h2&gt;


&lt;p&gt;
The system results demonstrate the potential for fewer follow-up visits, reduced anxiety, as well as lower overall costs for lung cancer screening. In an effort to translate this research into real-world clinical impact, we are working with &lt;a href="https://deephealth.com/"&gt;DeepHealth&lt;/a&gt;, a leading AI-powered health informatics provider, and &lt;a href="https://apolloradiologyintl.com/"&gt;Apollo Radiology International&lt;/a&gt;, a leading provider of radiology services in India, to explore paths for incorporating this system into future products. In addition, we are looking to help other researchers studying how best to integrate ML model results into clinical workflows by &lt;a href="https://github.com/Google-Health/google-health/tree/master/ct_dicom"&gt;open sourcing code&lt;/a&gt; used for the reader study and incorporating the insights described in this blog. We hope that this will help accelerate the work of medical imaging researchers looking to conduct reader studies for their AI models, and catalyze translational research in the field.
&lt;/p&gt;


&lt;div style="line-height: 40%;"&gt;
    &lt;br /&gt;
&lt;/div&gt;
&lt;h2&gt;Acknowledgements&lt;/h2&gt;


&lt;p&gt;
&lt;em&gt;Key contributors to this project include Corbin Cunningham, Zaid Nabulsi, Ryan Najafi, Jie Yang, Charles Lau, Joseph R. Ledsam, Wenxing Ye, Diego Ardila, Scott M. McKinney, Rory Pilgrim, Hiroaki Saito, Yasuteru Shimamura, Mozziyar Etemadi, Yun Liu, David Melnick, Sunny Jansen, Nadia Harhen, David P. Nadich, Mikhail Fomitchev, Ziyad Helali, Shabir Adeel, Greg S. Corrado, Lily Peng, Daniel Tse, Shravya Shetty, Shruthi Prabhakara, Neeral Beladia, and Krish Eswaran. Thanks to Arnav Agharwal and Andrew Sellergren for their open sourcing support and Vivek Natarajan and Michael D. Howell for their feedback. Sincere appreciation also goes to the radiologists who enabled this work with their image interpretation and annotation efforts throughout the study, and Jonny Wong and Carli Sampson for coordinating the reader studies.&lt;/em&gt;
&lt;/p&gt;&lt;p&gt;&lt;/p&gt;</content><link href="http://blog.research.google/feeds/7061041222399769838/comments/default" rel="replies" title="Post Comments" type="application/atom+xml"/><link href="http://blog.research.google/2024/03/computer-aided-diagnosis-for-lung.html#comment-form" rel="replies" title="0 Comments" type="text/html"/><link href="http://www.blogger.com/feeds/8474926331452026626/posts/default/7061041222399769838" rel="edit" type="application/atom+xml"/><link href="http://www.blogger.com/feeds/8474926331452026626/posts/default/7061041222399769838" rel="self" type="application/atom+xml"/><link href="http://blog.research.google/2024/03/computer-aided-diagnosis-for-lung.html" rel="alternate" title="Computer-aided diagnosis for lung cancer screening" type="text/html"/><author><name>Google AI</name><uri>http://www.blogger.com/profile/12098626514775266161</uri><email>noreply@blogger.com</email><gd:image height="16" rel="http://schemas.google.com/g/2005#thumbnail" src="https://img1.blogblog.com/img/b16-rounded.gif" width="16"/></author><media:thumbnail xmlns:media="http://search.yahoo.com/mrss/" height="72" url="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjFpuCd82OUmuS2oG2cVir_ZgeOyUpFndr-kCq8V4pDv6fzxeyViBJymfVt5FFUqgkM_X57msxNv84XBtaXs2FsD7R8_tNqtH6D8X_KiMtZRaJ37JphQsvM35_gIk-4Tn2eEYvrInjMLV5ouwhRJv3Oqb30Z71P546NszeURINBoJnlWnzgASn-6D9YFwZo/s72-c/PULMA%20hero.jpg" width="72"/><thr:total>0</thr:total></entry><entry><id>tag:blogger.com,1999:blog-8474926331452026626.post-4615278636568583418</id><published>2024-03-20T09:06:00.000-07:00</published><updated>2024-03-20T09:06:06.753-07:00</updated><category scheme="http://www.blogger.com/atom/ns#" term="Environment"/><category scheme="http://www.blogger.com/atom/ns#" term="Machine Learning"/><title type="text">Using AI to expand global access to reliable flood forecasts</title><content type="html">&lt;span class="byline-author"&gt;Posted by Yossi Matias, VP Engineering &amp;amp; Research, and 
Grey Nearing, Research Scientist, Google Research&lt;/span&gt;

&lt;img src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgABDUlqCHMxNY-QfEftM_9yPy1z4jr1odB-_kSP79yjk6igtpPJNFIocQOKDRnZ3VLmqrI9tqX-dCHpcYtnSx96y9X9V9knp1CiAREvfgZX71D0XpWZNgPdZOI7aMW3POigHJ2rLeA1G1asaAPO3KIB3j0WzUr5C707I7p0L_itspYYEhYDhDTzd39tNUD/s320/Flood%20forecasting%20hero%20image.jpg" style="display: none;" /&gt;

&lt;p&gt;
Floods are the &lt;a href="https://openknowledge.worldbank.org/server/api/core/bitstreams/e218989e-8b3b-5f8c-944c-06e9812215aa/content"&gt;most common natural disaster&lt;/a&gt; and are responsible for roughly &lt;a href="https://www.swissre.com/risk-knowledge/mitigating-climate-risk/floods.html"&gt;$50 billion&lt;/a&gt; in annual financial damages worldwide. The &lt;a href="https://library.wmo.int/records/item/57630-2021-state-of-climate-services-water?offset=1#:~:text=WMO%2DNo.,1278&amp;amp;text=More%20than%202%20billion%20people,for%20the%20past%2020%20years."&gt;rate of flood-related disasters has more than doubled&lt;/a&gt; since the year 2000, partly &lt;a href="https://www.nature.com/articles/s41598-020-70816-2"&gt;due to climate change&lt;/a&gt;. Nearly &lt;a href="https://openknowledge.worldbank.org/server/api/core/bitstreams/e218989e-8b3b-5f8c-944c-06e9812215aa/content"&gt;1.5 billion people&lt;/a&gt;, making up 19% of the world’s population, are exposed to substantial risks from severe flood events. Upgrading early warning systems to make accurate and timely information accessible to these populations &lt;a href="https://elibrary.worldbank.org/doi/abs/10.1596/1813-9450-6058"&gt;can save thousands of lives per year&lt;/a&gt;. 
&lt;/p&gt;
&lt;a name='more'&gt;&lt;/a&gt;
&lt;p&gt;
Driven by the potential impact of reliable flood forecasting on people’s lives globally, we started our flood forecasting effort in 2017. Through this &lt;a href="https://blog.google/technology/ai/google-ai-global-flood-forecasting/"&gt;multi-year journey&lt;/a&gt;, we have advanced research hand-in-hand with building a real-time operational flood forecasting system that &lt;a href="https://blog.google/technology/ai/expanding-our-ml-based-flood-forecasting/"&gt;provides alerts&lt;/a&gt; on Google Search, Maps, Android notifications, and through the &lt;a href="http://g.co/floodhub"&gt;Flood Hub&lt;/a&gt;. However, in order to &lt;a href="https://blog.google/outreach-initiatives/sustainability/flood-hub-ai-flood-forecasting-more-countries/"&gt;scale globally&lt;/a&gt;, especially in places where accurate local data is not available, more research advances were required.
&lt;/p&gt;

&lt;p&gt;
In “&lt;a href="https://www.nature.com/articles/s41586-024-07145-1"&gt;Global prediction of extreme floods in ungauged watersheds&lt;/a&gt;”, published in &lt;em&gt;&lt;a href="https://www.nature.com/"&gt;Nature&lt;/a&gt;&lt;/em&gt;, we demonstrate how machine learning (ML) technologies can significantly improve global-scale &lt;a href="https://sites.research.google/floodforecasting/"&gt;flood forecasting&lt;/a&gt; relative to the current state-of-the-art for countries where flood-related data is scarce. With these AI-based technologies we extended the reliability of currently available global nowcasts, on average, from zero to five days, and improved forecasts across regions in Africa and Asia to be similar to those currently available in Europe. The evaluation of the models was conducted in collaboration with the European Centre for Medium-Range Weather Forecasts (&lt;a href="https://www.ecmwf.int/"&gt;ECMWF&lt;/a&gt;).
&lt;/p&gt;

&lt;p&gt;
These technologies also enable &lt;a href="http://g.co/floodhub"&gt;Flood Hub&lt;/a&gt; to provide real-time river forecasts up to seven days in advance, &lt;a href="https://blog.google/outreach-initiatives/sustainability/flood-hub-ai-flood-forecasting-more-countries/"&gt;covering&lt;/a&gt; river reaches across over 80 countries. This information can be used by people, communities, governments and international organizations to take anticipatory action to help protect vulnerable populations.
&lt;/p&gt;

&lt;br /&gt;
&lt;div class="separator" style="clear: both; text-align: center;"&gt;&lt;iframe allowfullscreen="" class="BLOG_video_class" frameborder="0" height="360" src="https://www.youtube.com/embed/ET04pDj-RvM?si=WJJXEtwJqtyMRuC_?rel=0&amp;amp;" width="640" youtube-src-id="[ET04pDj-RvM]"&gt;&lt;/iframe&gt;&lt;/div&gt;
&lt;br /&gt;

&lt;div style="line-height: 40%;"&gt;
    &lt;br /&gt;
&lt;/div&gt;
&lt;h2&gt;Flood forecasting at Google &lt;/h2&gt;


&lt;p&gt;
The ML models that power the Flood Hub tool are the product of many years of research, conducted in collaboration with several partners, including academics, governments, international organizations, and NGOs. 
&lt;/p&gt;

&lt;p&gt;
In 2018, we &lt;a href="https://blog.google/products/search/helping-keep-people-safe-ai-enabled-flood-forecasting/"&gt;launched a pilot&lt;/a&gt; early warning system in the Ganges-Brahmaputra river basin in India, with the &lt;a href="https://arxiv.org/abs/1901.09583"&gt;hypothesis&lt;/a&gt; that ML could help address the challenging problem of reliable flood forecasting at scale. The pilot was further &lt;a href="https://blog.google/technology/ai/tracking-our-progress-on-flood-forecasting/"&gt;expanded&lt;/a&gt; the following year &lt;a href="https://ai.googleblog.com/2019/09/an-inside-look-at-flood-forecasting.html"&gt;via the combination&lt;/a&gt; of an inundation model, real-time water level measurements, the creation of an elevation map and hydrologic modeling.
&lt;/p&gt;

&lt;p&gt;
In &lt;a href="https://ai.googleblog.com/2019/03/a-summary-of-google-flood-forecasting.html"&gt;collaboration&lt;/a&gt; with academics, and, in particular, with the &lt;a href="https://www.jku.at/en/institute-for-machine-learning/"&gt;JKU Institute for Machine Learning&lt;/a&gt;, we explored ML-based hydrologic models, showing that &lt;a href="https://colah.github.io/posts/2015-08-Understanding-LSTMs/"&gt;LSTM&lt;/a&gt;-based models could &lt;a href="https://hess.copernicus.org/articles/23/5089/2019/"&gt;produce more accurate simulations&lt;/a&gt; than traditional conceptual and physics-based &lt;a href="https://en.wikipedia.org/wiki/Hydrological_model"&gt;hydrology models&lt;/a&gt;. This research led to &lt;a href="https://blog.research.google/2020/09/the-technology-behind-our-recent.html"&gt;flood forecasting improvements&lt;/a&gt; that enabled the &lt;a href="https://blog.google/technology/ai/flood-forecasts-india-bangladesh/"&gt;expansion&lt;/a&gt; of our forecasting coverage to include all of India and Bangladesh. We also worked with researchers at Yale University to test technological interventions that increase the &lt;a href="https://egc.yale.edu/about/perspectives/pande-and-coauthors-using-technology-save-lives-during-indias-monsoon-season"&gt;reach and impact&lt;/a&gt; of flood warnings.
&lt;/p&gt;

&lt;p&gt;
Our hydrological models predict river floods by processing publicly available weather data like precipitation and physical watershed information. Such models must be calibrated to long data records from &lt;a href="https://en.wikipedia.org/wiki/Stream_gauge"&gt;streamflow gauging stations&lt;/a&gt; in individual rivers. A low percentage of global river watersheds (basins) have streamflow gauges, which are expensive but necessary to supply relevant data, and it’s challenging for hydrological simulation and forecasting to provide &lt;a href="https://www.tandfonline.com/doi/full/10.1080/02626667.2013.803183"&gt;predictions in basins&lt;/a&gt; that lack this infrastructure. Lower &lt;a href="https://www.pnas.org/doi/full/10.1073/pnas.1414439112"&gt;gross domestic product&lt;/a&gt; (GDP) is correlated with increased &lt;a href="https://www.pnas.org/doi/full/10.1073/pnas.1414439112"&gt;vulnerability to flood risks&lt;/a&gt;, and there is an inverse correlation between national GDP and the amount of publicly available data in a country. ML helps to address this problem by allowing a &lt;a href="https://www.pnas.org/doi/full/10.1073/pnas.1414439112"&gt;single model to be trained on all available river data&lt;/a&gt; and to be applied to ungauged basins where &lt;a href="https://agupubs.onlinelibrary.wiley.com/doi/10.1029/2020wr028091"&gt;no data are available&lt;/a&gt;. In this way, models can be trained globally, and can make predictions for any river location.
&lt;/p&gt;


&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjxQUgMZAg0tVPN5LrxYbhpn3dukUCVogsWPgynrYNjFfbXpwK0RF79rYvK9kyehrha0F-vMLZR2eqBWdKCuGter6VoZrbCKnROTNn_hmOXBDxWmOFhFRvyg36ghO0B08fsQv7cqXdyngtfgCAgF5LhONs5VDzyvYjxzEYejVN3FxvzRs8w9Q5EeGJJTr3O/s1051/Streamflow%20data%20from%20the%20Global%20Runoff%20Data%20Center.jpg" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="788" data-original-width="1051" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjxQUgMZAg0tVPN5LrxYbhpn3dukUCVogsWPgynrYNjFfbXpwK0RF79rYvK9kyehrha0F-vMLZR2eqBWdKCuGter6VoZrbCKnROTNn_hmOXBDxWmOFhFRvyg36ghO0B08fsQv7cqXdyngtfgCAgF5LhONs5VDzyvYjxzEYejVN3FxvzRs8w9Q5EeGJJTr3O/s16000/Streamflow%20data%20from%20the%20Global%20Runoff%20Data%20Center.jpg" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;There is an inverse (log-log) correlation between the amount of publicly available streamflow data in a country and national GDP. Streamflow data from the &lt;a href="https://www.bafg.de/GRDC/EN/Home/homepage_node.html"&gt;Global Runoff Data Center&lt;/a&gt;.&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;



&lt;p&gt;
Our academic collaborations led to ML research that developed methods to &lt;a href="https://agupubs.onlinelibrary.wiley.com/doi/10.1029/2020wr028091"&gt;estimate uncertainty in river forecasts&lt;/a&gt; and showed how ML river forecast models &lt;a href="https://hess.copernicus.org/articles/25/2685/2021/hess-25-2685-2021-relations.html"&gt;synthesize information from multiple data sources&lt;/a&gt;. This research also demonstrated that these models can &lt;a href="https://hess.copernicus.org/articles/26/3377/2022/hess-26-3377-2022.html"&gt;simulate extreme events reliably&lt;/a&gt;, even when those events are not part of the training data. In an effort to &lt;a href="https://blog.research.google/2023/04/directing-ml-toward-natural-hazard.html"&gt;contribute&lt;/a&gt; to open science, in 2023 we open-sourced a community-driven dataset for large-sample hydrology in &lt;em&gt;&lt;a href="https://www.nature.com/articles/s41597-023-01975-w"&gt;Scientific Data&lt;/a&gt;&lt;/em&gt;. 
&lt;/p&gt;


&lt;div style="line-height: 40%;"&gt;
    &lt;br /&gt;
&lt;/div&gt;
&lt;h2&gt;The river forecast model&lt;/h2&gt;


&lt;p&gt;
Most hydrology models used by national and international agencies for flood forecasting and river modeling are state-space models, which depend only on daily inputs (e.g., precipitation, temperature, etc.) and the current state of the system (e.g., soil moisture, snowpack, etc.). LSTMs are a variant of state-space models and work by defining a neural network that represents a single time step, where input data (such as current weather conditions) are processed to produce updated state information and output values (streamflow) for that time step. LSTMs are applied sequentially to make time-series predictions, and in this sense, behave similarly to how scientists typically conceptualize hydrologic systems. Empirically, we have found that &lt;a href="https://hess.copernicus.org/articles/23/5089/2019/"&gt;LSTMs perform well&lt;/a&gt; on the task of river forecasting.
&lt;/p&gt;


&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgMfiw33NkHO8CQsYGWSZ91xhPx0iDONFLe8WZuRWDsoi8RRv7pHlF6M8eDLEWpO8lZECUfGi59_NsMXO8ASDZQ9xxrB87mupNTPpioKT0wRgSSc1FwYDmfCUWyooGGZmvMhZv0RDcWJVslQOPvRNOK_B6dXUGsnijSl-W-lICOIbALAwNC2PNEmqqXhv6g/s960/image1.gif" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="540" data-original-width="960" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgMfiw33NkHO8CQsYGWSZ91xhPx0iDONFLe8WZuRWDsoi8RRv7pHlF6M8eDLEWpO8lZECUfGi59_NsMXO8ASDZQ9xxrB87mupNTPpioKT0wRgSSc1FwYDmfCUWyooGGZmvMhZv0RDcWJVslQOPvRNOK_B6dXUGsnijSl-W-lICOIbALAwNC2PNEmqqXhv6g/s16000/image1.gif" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;A diagram of the LSTM, which is a neural network that operates sequentially in time. An accessible primer can be found &lt;a href="https://colah.github.io/posts/2015-08-Understanding-LSTMs/"&gt;here&lt;/a&gt;.&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;



&lt;p&gt;
Our river forecast model uses two LSTMs applied sequentially: (1) a “hindcast” LSTM ingests historical weather data (dynamic hindcast features) up to the present time (or rather, the issue time of a forecast), and (2) a “forecast” LSTM ingests states from the hindcast LSTM along with forecasted weather data (dynamic forecast features) to make future predictions. One year of historical weather data are input into the hindcast LSTM, and seven days of forecasted weather data are input into the forecast LSTM. Static features include geographical and geophysical characteristics of watersheds that are input into both the hindcast and forecast LSTMs and allow the model to learn different hydrological behaviors and responses in various types of watersheds. 
&lt;/p&gt;

&lt;p&gt;
Output from the forecast LSTM is fed into a “head” layer that uses &lt;a href="https://publications.aston.ac.uk/id/eprint/373/1/NCRG_94_004.pdf"&gt;mixture density networks&lt;/a&gt; to produce a probabilistic forecast (i.e., predicted parameters of a probability distribution over streamflow). Specifically, the model predicts the parameters of a mixture of heavy-tailed probability density functions, called &lt;a href="https://en.wikipedia.org/wiki/Asymmetric_Laplace_distribution"&gt;asymmetric Laplacian distributions&lt;/a&gt;, at each forecast time step. The result is a mixture density function, called a &lt;a href="https://proceedings.neurips.cc/paper_files/paper/2019/file/d80126524c1e9641333502c664fc6ca1-Paper.pdf"&gt;Countable Mixture of Asymmetric Laplacians&lt;/a&gt; (CMAL) distribution, which represents a probabilistic prediction of the volumetric flow rate in a particular river at a particular time. 
&lt;/p&gt;


&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjVPR4LA0EbJyAesDg4HvrMdxgG_0wiyLqJveir2Ryy06qDNVshkM2-zHvMj_y1LEBXOSm7ajMx2qzYCLNQrQ3dm8TRicy_wkTVtM4Xio_mhQPsgaSiN3sm3J8BBNYNpxWQbSm_aTSMyRW9UyIEWAAT9secPekdYNzyKRrXwgm10-ksyeUzTFRydXnt_Wai/s960/LSTM-based%20river%20forecast%20model.jpeg" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="480" data-original-width="960" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjVPR4LA0EbJyAesDg4HvrMdxgG_0wiyLqJveir2Ryy06qDNVshkM2-zHvMj_y1LEBXOSm7ajMx2qzYCLNQrQ3dm8TRicy_wkTVtM4Xio_mhQPsgaSiN3sm3J8BBNYNpxWQbSm_aTSMyRW9UyIEWAAT9secPekdYNzyKRrXwgm10-ksyeUzTFRydXnt_Wai/s16000/LSTM-based%20river%20forecast%20model.jpeg" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;LSTM-based river forecast model architecture. Two LSTMs are applied in sequence, one ingesting historical weather data and one ingesting forecasted weather data. The model outputs are the parameters of a probability distribution over streamflow at each forecasted timestep.&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;

&lt;br /&gt;

&lt;div style="line-height: 40%;"&gt;
    &lt;br /&gt;
&lt;/div&gt;
&lt;h2&gt;Input and training data&lt;/h2&gt;


&lt;p&gt;
The model uses three types of publicly available data inputs, mostly from governmental sources:
&lt;/p&gt;
&lt;ol&gt;

&lt;li&gt;&lt;em&gt;Static watershed attributes representing geographical and geophysical variables:&lt;/em&gt; From the &lt;a href="https://www.hydrosheds.org/hydroatlas"&gt;HydroATLAS project&lt;/a&gt;, including data like long-term climate indexes (precipitation, temperature, snow fractions), land cover, and anthropogenic attributes (e.g., a nighttime lights index as a proxy for human development). 

&lt;/li&gt;&lt;li&gt;&lt;em&gt;Historical meteorological time-series data&lt;/em&gt;: Used to spin up the model for one year prior to the issue time of a forecast. The data come from &lt;a href="https://gpm.nasa.gov/data/imerg"&gt;NASA IMERG&lt;/a&gt;, &lt;a href="https://psl.noaa.gov/data/gridded/data.cpc.globalprecip.html"&gt;NOAA CPC Global Unified Gauge-Based Analysis of Daily Precipitation&lt;/a&gt;, and the &lt;a href="https://cds.climate.copernicus.eu/cdsapp#!/dataset/reanalysis-era5-land?tab=overview"&gt;ECMWF ERA5-Land reanalysis&lt;/a&gt;. Variables include daily total precipitation, air temperature, solar and thermal radiation, snowfall, and surface pressure. 

&lt;/li&gt;&lt;li&gt;&lt;em&gt;Forecasted meteorological time series over a seven-day forecast horizon&lt;/em&gt;: Used as input for the forecast LSTM. These data are the same meteorological variables listed above, and come from the &lt;a href="https://www.ecmwf.int/en/forecasts/datasets/set-i"&gt;ECMWF HRES atmospheric model&lt;/a&gt;.
&lt;/li&gt;
&lt;/ol&gt;

&lt;p&gt;
Training data are daily streamflow values from the &lt;a href="https://www.bafg.de/GRDC/EN/Home/homepage_node.html"&gt;Global Runoff Data Center&lt;/a&gt; over the time period 1980–2023. A single streamflow forecast model is trained using data from 5,680 diverse watershed streamflow gauges (shown below) to improve &lt;a href="https://eartharxiv.org/repository/view/6363/"&gt;accuracy&lt;/a&gt;.
&lt;/p&gt;


&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhJZa8BMczHa_WiWNB1FJvPgEcw5O6U_IumoXBvI3gB_cIqrbte2SZKu_Msr1MudCVPv3YF6L3BweAC0hhMkET634isx6xzUswrYfDwp8oueoWJ7c3hf0os-RIsaNrdgAboc7HUly0rGtuBt6OVQ-MnY5P44DKOXSHKYl_T-gMz5z0ek8CHk0lIx45fnZYU/s1417/gauge_locations_map(1).jpg" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="689" data-original-width="1417" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhJZa8BMczHa_WiWNB1FJvPgEcw5O6U_IumoXBvI3gB_cIqrbte2SZKu_Msr1MudCVPv3YF6L3BweAC0hhMkET634isx6xzUswrYfDwp8oueoWJ7c3hf0os-RIsaNrdgAboc7HUly0rGtuBt6OVQ-MnY5P44DKOXSHKYl_T-gMz5z0ek8CHk0lIx45fnZYU/s16000/gauge_locations_map(1).jpg" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;Location of 5,680 streamflow gauges that supply training data for the river forecast model from the &lt;a href="https://www.bafg.de/GRDC/EN/Home/homepage_node.html"&gt;Global Runoff Data Center&lt;/a&gt;.&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;

&lt;br /&gt;
  
  
&lt;div style="line-height: 40%;"&gt;
    &lt;br /&gt;
&lt;/div&gt;  
&lt;h2&gt;Improving on the current state-of-the-art&lt;/h2&gt;


&lt;p&gt;
We compared our river forecast model with &lt;a href="https://www.globalfloods.eu/"&gt;GloFAS version 4&lt;/a&gt;, the current state-of-the-art global flood forecasting system. These experiments showed that ML can provide accurate warnings earlier and over larger and more impactful events. 
&lt;/p&gt;

&lt;p&gt;
The figure below shows the distribution of &lt;a href="https://en.wikipedia.org/wiki/F-score"&gt;F1 scores&lt;/a&gt; when predicting events of different severity at river locations around the world, with plus or minus 1 day accuracy. The F1 score is the harmonic mean of precision and recall, and event severity is measured by &lt;a href="https://en.wikipedia.org/wiki/Return_period#:~:text=A%20return%20period%2C%20also%20known,river%20discharge%20flows%20to%20occur."&gt;return period&lt;/a&gt;. For example, a 2-year return period event is a volume of streamflow that is expected to be exceeded on average once every two years. Our model achieves reliability scores at up to 4-day or 5-day lead times that are similar to or better, on average, than the reliability of GloFAS nowcasts (0-day lead time). 
&lt;/p&gt;


&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgjwzwV6QYl4yIlWs1xdHz2HRiNi2I8WUaTGBVlVvA4guppIGpJ3RMj8ypE7chWz8sV5KJuS4dPe9PUd6TqWe46W8Yelga1Nq28Mts72zqJhLJXDgMjSa6VCHlb9ZH3eo8XETWSqj8lNraejCAezFpkGpfJrPIl4xMhRPHSdO1WX7bZmVSLDFMZOwMfarb5/s3908/Distributions%20of%20F1%20scores%20over%202-year%20.jpeg" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="1844" data-original-width="3908" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgjwzwV6QYl4yIlWs1xdHz2HRiNi2I8WUaTGBVlVvA4guppIGpJ3RMj8ypE7chWz8sV5KJuS4dPe9PUd6TqWe46W8Yelga1Nq28Mts72zqJhLJXDgMjSa6VCHlb9ZH3eo8XETWSqj8lNraejCAezFpkGpfJrPIl4xMhRPHSdO1WX7bZmVSLDFMZOwMfarb5/s16000/Distributions%20of%20F1%20scores%20over%202-year%20.jpeg" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;Distributions of &lt;a href="https://en.wikipedia.org/wiki/F-score"&gt;F1 scores&lt;/a&gt; over 2-year return period events in 2,092 watersheds globally during the time period 2014-2023 from GloFAS (&lt;strong&gt;blue&lt;/strong&gt;) and our model (&lt;strong&gt;orange&lt;/strong&gt;) at different lead times. On average, our model is statistically as accurate as GloFAS nowcasts (0–day lead time) up to 5 days in advance over 2-year (shown) and 1-year, 5-year, and 10-year events (not shown).&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;


&lt;p&gt;
Additionally (not shown), our model also maintains its accuracy on larger and rarer extreme events, with precision and recall scores on 5-year return period events that are similar to or better than GloFAS accuracies on 1-year return period events. See the &lt;a href="https://www.nature.com/articles/s41586-024-07145-1"&gt;paper&lt;/a&gt; for more information.
&lt;/p&gt;

&lt;div style="line-height: 40%;"&gt;
    &lt;br /&gt;
&lt;/div&gt;
&lt;h2&gt;Looking into the future&lt;/h2&gt;


&lt;p&gt;
The flood forecasting initiative is part of our &lt;a href="https://blog.google/outreach-initiatives/sustainability/google-ai-climate-change-solutions/"&gt;Adaptation and Resilience efforts&lt;/a&gt; and reflects Google's commitment&amp;nbsp;&lt;a href="https://research.google/teams/climate-and-sustainability/"&gt;to address climate change&lt;/a&gt; while helping global communities become more resilient. We believe that AI and ML will continue to play a critical role in helping advance science and research towards climate action.
&lt;/p&gt;

&lt;p&gt;
We actively &lt;a href="https://blog.google/outreach-initiatives/sustainability/4-flood-forecasting-collaboration-case-studies-show-how-ai-can-help-communities-in-need/"&gt;collaborate&lt;/a&gt; with several international aid organizations (e.g., the Centre for Humanitarian Data and the Red Cross) to provide actionable flood forecasts. Additionally, in an ongoing collaboration with the &lt;a href="https://wmo.int/"&gt;World Meteorological Organization&lt;/a&gt; (WMO) to &lt;a href="https://blog.google/outreach-initiatives/sustainability/early-warning-system-wmo-google/"&gt;support early warning systems&lt;/a&gt; for climate hazards, we are conducting a study to help understand how AI can help address real-world challenges faced by national flood forecasting agencies. 
&lt;/p&gt;

&lt;p&gt;
While the work presented here demonstrates a significant step forward in flood forecasting, future work is needed to further expand flood forecasting coverage to more locations globally and to other types of flood-related events and disasters, including flash floods and urban floods. We are looking forward to continuing collaborations with our partners in the academic and expert communities, local governments and the industry to reach these goals.
&lt;/p&gt;</content><link href="http://blog.research.google/feeds/4615278636568583418/comments/default" rel="replies" title="Post Comments" type="application/atom+xml"/><link href="http://blog.research.google/2024/03/using-ai-to-expand-global-access-to.html#comment-form" rel="replies" title="0 Comments" type="text/html"/><link href="http://www.blogger.com/feeds/8474926331452026626/posts/default/4615278636568583418" rel="edit" type="application/atom+xml"/><link href="http://www.blogger.com/feeds/8474926331452026626/posts/default/4615278636568583418" rel="self" type="application/atom+xml"/><link href="http://blog.research.google/2024/03/using-ai-to-expand-global-access-to.html" rel="alternate" title="Using AI to expand global access to reliable flood forecasts" type="text/html"/><author><name>Google AI</name><uri>http://www.blogger.com/profile/12098626514775266161</uri><email>noreply@blogger.com</email><gd:image height="16" rel="http://schemas.google.com/g/2005#thumbnail" src="https://img1.blogblog.com/img/b16-rounded.gif" width="16"/></author><media:thumbnail xmlns:media="http://search.yahoo.com/mrss/" height="72" url="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgABDUlqCHMxNY-QfEftM_9yPy1z4jr1odB-_kSP79yjk6igtpPJNFIocQOKDRnZ3VLmqrI9tqX-dCHpcYtnSx96y9X9V9knp1CiAREvfgZX71D0XpWZNgPdZOI7aMW3POigHJ2rLeA1G1asaAPO3KIB3j0WzUr5C707I7p0L_itspYYEhYDhDTzd39tNUD/s72-c/Flood%20forecasting%20hero%20image.jpg" width="72"/><thr:total>0</thr:total></entry><entry><id>tag:blogger.com,1999:blog-8474926331452026626.post-520087429457973735</id><published>2024-03-19T13:15:00.000-07:00</published><updated>2024-03-19T13:15:33.664-07:00</updated><category scheme="http://www.blogger.com/atom/ns#" term="HCI"/><category scheme="http://www.blogger.com/atom/ns#" term="Multimodal Learning"/><category scheme="http://www.blogger.com/atom/ns#" term="Self-Supervised Learning"/><category scheme="http://www.blogger.com/atom/ns#" term="UI"/><title type="text">ScreenAI: A visual language 
model for UI and visually-situated language understanding</title><content type="html">&lt;span class="byline-author"&gt;Posted by Srinivas Sunkara and Gilles Baechler, Software Engineers, Google Research&lt;/span&gt;


&lt;img src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhoXlMR7pAKRRnyKZT8C40i6mPX0KKNGT6AFNvFOFIhZ7BD0rXaU3NS_aqISTGq9S_d0zozgcO0HR_v3R6Msm4uUDkaBFsFVx-miaDL6L0UhSz1Is8_L_iFjtvNE5OX9HX98t92b3r-rLQfJG1RrzVW354NdVUlIJVRLdQ_l4dFYa1773J-tJligdvh7QsX/s320/ScreenAI%20-%20hero.jpeg" style="display: none;" /&gt;

&lt;p&gt;
Screen user interfaces (UIs) and infographics, such as charts, diagrams and tables, play important roles in human communication and human-machine interaction as they facilitate rich and interactive user experiences. UIs and infographics share similar design principles and visual language (e.g., icons and layouts), which offers an opportunity to build a single model that can understand, reason about, and interact with these interfaces. However, because of their complexity and varied presentation formats, infographics and UIs present a unique modeling challenge.
&lt;/p&gt;
&lt;a name='more'&gt;&lt;/a&gt;
&lt;p&gt;
To that end, we introduce “&lt;a href="https://arxiv.org/abs/2402.04615"&gt;ScreenAI: A Vision-Language Model for UI and Infographics Understanding&lt;/a&gt;”. ScreenAI improves upon the &lt;a href="https://arxiv.org/abs/2305.18565"&gt;PaLI architecture&lt;/a&gt; with the flexible patching strategy from &lt;a href="https://arxiv.org/abs/2210.03347"&gt;pix2struct&lt;/a&gt;. We train ScreenAI on a unique mixture of datasets and tasks, including a novel Screen Annotation task that requires the model to identify UI element information (i.e., type, location and description) on a screen. These text annotations provide large language models (LLMs) with screen descriptions, enabling them to automatically generate question-answering (QA), UI navigation, and summarization training datasets at scale. At only 5B parameters, ScreenAI achieves state-of-the-art results on UI- and infographic-based tasks (&lt;a href="https://x-lance.github.io/WebSRC/"&gt;WebSRC&lt;/a&gt; and &lt;a href="https://github.com/aburns4/MoTIF"&gt;MoTIF&lt;/a&gt;), and best-in-class performance on &lt;a href="https://github.com/vis-nlp/ChartQA"&gt;Chart QA&lt;/a&gt;, &lt;a href="https://rrc.cvc.uab.es/?ch=17&amp;amp;com=evaluation&amp;amp;task=1"&gt;DocVQA&lt;/a&gt;, and &lt;a href="https://arxiv.org/abs/2104.12756"&gt;InfographicVQA&lt;/a&gt; compared to models of similar size. We are also releasing three new datasets: &lt;a href="https://github.com/google-research-datasets/screen_qa?tab=readme-ov-file#screen-annotation-dataset-details"&gt;Screen Annotation&lt;/a&gt; to evaluate the layout understanding capability of the model, as well as &lt;a href="https://github.com/google-research-datasets/screen_qa/tree/main?tab=readme-ov-file#short_answers-directory"&gt;ScreenQA Short&lt;/a&gt; and &lt;a href="https://github.com/google-research-datasets/screen_qa?tab=readme-ov-file#complexqa" target="_blank"&gt;Complex ScreenQA&lt;/a&gt; for a more comprehensive evaluation of its QA capability. 
&lt;/p&gt;

&lt;div style="line-height: 40%;"&gt;
    &lt;br /&gt;
&lt;/div&gt;
&lt;h2&gt;ScreenAI&lt;/h2&gt;


&lt;p&gt;
ScreenAI’s architecture is based on &lt;a href="https://arxiv.org/abs/2209.06794"&gt;PaLI&lt;/a&gt;, composed of a multimodal encoder block and an autoregressive decoder. The PaLI encoder uses a &lt;a href="https://arxiv.org/abs/2010.11929"&gt;vision transformer&lt;/a&gt; (ViT) that creates image embeddings and a multimodal encoder that takes the concatenation of the image and text embeddings as input. This flexible architecture allows ScreenAI to solve vision tasks that can be recast as text+image-to-text problems. 
&lt;/p&gt;

&lt;p&gt;
On top of the PaLI architecture, we employ a flexible patching strategy introduced in pix2struct. Instead of using a fixed-grid pattern, the grid dimensions are selected such that they preserve the native aspect ratio of the input image. This enables ScreenAI to work well across images of various aspect ratios. 
&lt;/p&gt;

&lt;p&gt;
The ScreenAI model is trained in two stages: a pre-training stage followed by a fine-tuning stage. First, self-supervised learning is applied to automatically generate data labels, which are then used to train ViT and the language model. ViT is frozen during the fine-tuning stage, where most data used is manually labeled by human raters. 
&lt;/p&gt;


&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjS1qatfLUw6BZZgkPxrv0Hx1pAPAehiF8q3kfA0BUyyPx4XXpwZRr75nYl99fTIQwLNmOHXhSBbpzHDnw6yQXZls1ZV-IE-d75jP5M02cRSZTYuU8FJBS4mubPzUPIuvcj_oqkEJcWtNWtnLmPZ3P1jJlDmc8GA1WNq00jUwl2o8gfLIIXlknrjy4z6y7Y/s1600/image6.gif" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="583" data-original-width="1600" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjS1qatfLUw6BZZgkPxrv0Hx1pAPAehiF8q3kfA0BUyyPx4XXpwZRr75nYl99fTIQwLNmOHXhSBbpzHDnw6yQXZls1ZV-IE-d75jP5M02cRSZTYuU8FJBS4mubPzUPIuvcj_oqkEJcWtNWtnLmPZ3P1jJlDmc8GA1WNq00jUwl2o8gfLIIXlknrjy4z6y7Y/s16000/image6.gif" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;ScreenAI model architecture.&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;

&lt;br /&gt;


&lt;div style="line-height: 40%;"&gt;
    &lt;br /&gt;
&lt;/div&gt;
&lt;h2&gt;Data generation&lt;/h2&gt;


&lt;p&gt;
To create a pre-training dataset for ScreenAI, we first compile an extensive collection of screenshots from various devices, including desktops, mobile, and tablets. This is achieved by using &lt;a href="https://arxiv.org/abs/1910.10683" target="_blank"&gt;publicly accessible web pages&lt;/a&gt; and following the programmatic exploration approach used for the &lt;a href="https://dl.acm.org/doi/10.1145/3126594.3126651" target="_blank"&gt;RICO dataset&lt;/a&gt; for mobile apps. We then apply a layout annotator, based on the &lt;a href="https://arxiv.org/abs/2005.12872" target="_blank"&gt;DETR&lt;/a&gt; model, that identifies and labels a wide range of UI elements (e.g., image, pictogram, button, text) and their spatial relationships. Pictograms undergo further analysis using an &lt;a href="https://arxiv.org/abs/2210.02663" target="_blank"&gt;icon classifier&lt;/a&gt; capable of distinguishing 77 different icon types. This detailed classification is essential for interpreting the subtle information conveyed through icons. For icons that are not covered by the classifier, and for infographics and images, we use the PaLI image captioning model to generate descriptive captions that provide contextual information. We also apply an &lt;a href="https://cloud.google.com/use-cases/ocr" target="_blank"&gt;optical character recognition&lt;/a&gt; (OCR) engine to extract and annotate textual content on screen. We combine the OCR text with the previous annotations to create a detailed description of each screen.
&lt;/p&gt;


&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEj_wzxsb1U_PH17m3dG92ny7PpJjIYK39k1NQme1i5GM63tAd_OGdxMAV2_OQQVQSdkdyY1Tb3s8ibI2M3Kp1VpdNMsBr0ugBcBdL_r6dUwOwdfJfBMn3ae9Zl3zM2IpfZV654DFybMhMLimy0cuUNsnU5L8O2byu9eHmhdWcIvsb1t8AWi-tKNkXFq7Neo/s1747/image2.png" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="1055" data-original-width="1747" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEj_wzxsb1U_PH17m3dG92ny7PpJjIYK39k1NQme1i5GM63tAd_OGdxMAV2_OQQVQSdkdyY1Tb3s8ibI2M3Kp1VpdNMsBr0ugBcBdL_r6dUwOwdfJfBMn3ae9Zl3zM2IpfZV654DFybMhMLimy0cuUNsnU5L8O2byu9eHmhdWcIvsb1t8AWi-tKNkXFq7Neo/s16000/image2.png" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;A mobile app screenshot with generated annotations that include UI elements and their descriptions, e.g., &lt;code&gt;TEXT&lt;/code&gt; elements also contain the text content from OCR, &lt;code&gt;IMAGE&lt;/code&gt; elements contain image captions, &lt;code&gt;LIST_ITEMs&lt;/code&gt; contain all their child elements.&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;
&lt;br /&gt;



&lt;div style="line-height: 40%;"&gt;
    &lt;br /&gt;
&lt;/div&gt;
&lt;h3&gt;LLM-based data generation&lt;/h3&gt;


&lt;p&gt;
We enhance the pre-training data's diversity using &lt;a href="https://blog.google/technology/ai/google-palm-2-ai-large-language-model/"&gt;PaLM 2&lt;/a&gt; to generate input-output pairs in a two-step process. First, screen annotations are generated using the technique outlined above, then we craft a prompt around this schema for the LLM to create synthetic data. This process requires prompt engineering and iterative refinement to find an effective prompt. We assess the generated data's quality through human validation against a quality threshold. 
&lt;/p&gt;


&lt;br /&gt;
&lt;pre class="prettyprint" style="margin-left: 40px; margin-right: 40px; white-space: pre-wrap;"&gt;&lt;font color="#008000"&gt;You only speak JSON. Do not write text that isn’t JSON.
You are given the following mobile screenshot, described in words. Can you generate 5 questions regarding the content of the screenshot as well as the corresponding short answers to them? 

The answer should be as short as possible, containing only the necessary information. Your answer should be structured as follows:
questions: [
{{question: the question,
    answer: the answer
}},
 ...
]

{THE SCREEN SCHEMA}
&lt;/font&gt;&lt;/pre&gt;
&lt;br /&gt;
&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;A sample prompt for QA data generation.&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;


&lt;p&gt;
By combining the natural language capabilities of LLMs with a structured schema, we simulate a wide range of user interactions and scenarios to generate synthetic, realistic tasks. In particular, we generate three categories of tasks:
&lt;/p&gt;

&lt;ul&gt;

&lt;li&gt;&lt;strong&gt;Question answering&lt;/strong&gt;: The model is asked to answer questions regarding the content of the screenshots, e.g., “When does the restaurant open?”

&lt;/li&gt;&lt;li&gt;&lt;strong&gt;Screen navigation&lt;/strong&gt;: The model is asked to convert a natural language utterance into an executable action on a screen, e.g., “Click the search button.”

&lt;/li&gt;&lt;li&gt;&lt;strong&gt;Screen summarization&lt;/strong&gt;: The model is asked to summarize the screen content in one or two sentences. 
&lt;/li&gt;
&lt;/ul&gt;


&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEiinxXWrVJQr3tZJ4-o3ipkdJriUqTRbi2CFWor4I2SpyMiswx6uZOM2ZJW0gZC75MXYshkjXPABvDuSnhR44ceNwDpkvaSLa4R3v4C-hEsnHdEc-JUUx31zZmDHDDwhWaMDqnD0wo6ibt7qBZfaYN_yx1myH77k-ruO9fjd33SiLnP0jLnjOfmhdEHbsR7/s1398/image3.png" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="1398" data-original-width="1272" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEiinxXWrVJQr3tZJ4-o3ipkdJriUqTRbi2CFWor4I2SpyMiswx6uZOM2ZJW0gZC75MXYshkjXPABvDuSnhR44ceNwDpkvaSLa4R3v4C-hEsnHdEc-JUUx31zZmDHDDwhWaMDqnD0wo6ibt7qBZfaYN_yx1myH77k-ruO9fjd33SiLnP0jLnjOfmhdEHbsR7/s16000/image3.png" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;Block diagram of our workflow for generating data for QA, summarization and navigation tasks using existing ScreenAI models and LLMs. Each task uses a custom prompt to emphasize desired aspects, like questions related to counting, involving reasoning, etc.&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;

&lt;br /&gt;



&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;img height="540" src="https://lh7-us.googleusercontent.com/LmUtXBMXK-zy_rMShHQ_Hk4vQeXu2Kpx8zfzjhE3uAREczbkbGTEjZ7OMTbqtB37lD4rF31xJsoWdVXNAXLbbM1Uc_01WZWmOfBg9RwyAUEToPpa1W38Pt117Zj5LrNfnxXqjXoAJDZd-zcAIgU4QSoBaAKsIrSi8_POI14F5hguN1NJL9a2RsrKg6WHz7w" style="margin-left: auto; margin-right: auto; margin-top: 0px;" width="705" /&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;LLM-generated data. Examples for screen QA, navigation and summarization. For navigation, the action bounding box is displayed in red on the screenshot.&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;

&lt;br /&gt;
&lt;div style="line-height: 40%;"&gt;
    &lt;br /&gt;
&lt;/div&gt;
&lt;h2&gt;Experiments and results&lt;/h2&gt;


&lt;p&gt;
As previously mentioned, ScreenAI is trained in two stages: pre-training and fine-tuning. Pre-training data labels are obtained using self-supervised learning, and fine-tuning data labels come from human raters.
&lt;/p&gt;

&lt;p&gt;
We fine-tune ScreenAI using public QA, summarization, and navigation datasets and a variety of tasks related to UIs. For QA, we use well-established benchmarks in the multimodal and document understanding field, such as &lt;a href="https://github.com/vis-nlp/ChartQA"&gt;ChartQA&lt;/a&gt;, &lt;a href="https://rrc.cvc.uab.es/?ch=17&amp;amp;com=evaluation&amp;amp;task=1"&gt;DocVQA&lt;/a&gt;, &lt;a href="https://rrc.cvc.uab.es/?ch=17&amp;amp;com=tasks"&gt;Multi page DocVQA&lt;/a&gt;, &lt;a href="https://arxiv.org/abs/2104.12756"&gt;InfographicVQA&lt;/a&gt;, &lt;a href="https://ocr-vqa.github.io/"&gt;OCR VQA&lt;/a&gt;, &lt;a href="https://x-lance.github.io/WebSRC/"&gt;Web SRC&lt;/a&gt; and &lt;a href="https://github.com/google-research-datasets/screen_qa"&gt;ScreenQA&lt;/a&gt;. For navigation, datasets used include &lt;a href="https://github.com/google-research-datasets/uibert/tree/main"&gt;Referring Expressions&lt;/a&gt;, &lt;a href="https://github.com/aburns4/MoTIF"&gt;MoTIF&lt;/a&gt;, &lt;a href="https://arxiv.org/abs/2209.15099"&gt;Mug&lt;/a&gt;, and &lt;a href="https://github.com/google-research/google-research/tree/master/android_in_the_wild"&gt;Android in the Wild&lt;/a&gt;. Finally, we use &lt;a href="https://github.com/google-research-datasets/screen2words"&gt;Screen2Words&lt;/a&gt; for screen summarization and &lt;a href="https://paperswithcode.com/paper/widget-captioning-generating-natural-language/review/"&gt;Widget Captioning&lt;/a&gt; for describing specific UI elements. Along with the fine-tuning datasets, we evaluate the fine-tuned ScreenAI model using three novel benchmarks:
&lt;/p&gt;

&lt;ol&gt;

&lt;li&gt;Screen Annotation: Enables evaluation of the model's layout annotation and spatial understanding capabilities.

&lt;/li&gt;&lt;li&gt;ScreenQA Short: A variation of ScreenQA, where its ground truth answers have been shortened to contain only the relevant information that better aligns with other QA tasks.

&lt;/li&gt;&lt;li&gt;Complex ScreenQA: Complements ScreenQA Short with more difficult questions (counting, arithmetic, comparison, and non-answerable questions) and contains screens with various aspect ratios.
&lt;/li&gt;
&lt;/ol&gt;

&lt;p&gt;
The fine-tuned ScreenAI model achieves state-of-the-art results on various UI and infographic-based tasks (&lt;a href="https://x-lance.github.io/WebSRC/"&gt;WebSRC&lt;/a&gt; and &lt;a href="https://github.com/aburns4/MoTIF"&gt;MoTIF&lt;/a&gt;) and best-in-class performance on &lt;a href="https://github.com/vis-nlp/ChartQA"&gt;Chart QA&lt;/a&gt;, &lt;a href="https://rrc.cvc.uab.es/?ch=17&amp;amp;com=evaluation&amp;amp;task=1"&gt;DocVQA&lt;/a&gt;, and &lt;a href="https://arxiv.org/abs/2104.12756"&gt;InfographicVQA&lt;/a&gt; compared to models of similar size. ScreenAI achieves competitive performance on Screen2Words and OCR-VQA. Additionally, we report results on the new benchmark datasets introduced to serve as a baseline for further research.
&lt;/p&gt;

&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEijJAw824LdVbrFU3c7oerx9Ik86dWnuQ2NqliLpUZLp6U-9pDxZKsw10VSMfYOSwns-GWJRdSCj3UmyxytOZxfoM64psBSKCjLYa-3zkXDt8mGvFbNpydwS1Ya2dhDeYfihWL1mVCyTWIzdgfblxawoxukWW1vLLwfNWMNKQ64B8wUM5SlNKgegdGxXlr7/s1183/image2.png" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="1137" data-original-width="1183" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEijJAw824LdVbrFU3c7oerx9Ik86dWnuQ2NqliLpUZLp6U-9pDxZKsw10VSMfYOSwns-GWJRdSCj3UmyxytOZxfoM64psBSKCjLYa-3zkXDt8mGvFbNpydwS1Ya2dhDeYfihWL1mVCyTWIzdgfblxawoxukWW1vLLwfNWMNKQ64B8wUM5SlNKgegdGxXlr7/s16000/image2.png" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;Comparing model performance of ScreenAI with state-of-the-art (SOTA) models of similar size.&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;




&lt;p&gt;
Next, we examine ScreenAI’s scaling capabilities and observe that across all tasks, increasing the model size improves performance, and the improvements have not saturated at the largest size.
&lt;/p&gt;


&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjKNMvTyz1RhM0wqgn7eAGB9Lev3YUhKhHrcAmJt3SB1Gi6ozIaxHoPzAj-bm6II-_91viG2FXrfNZiiwSSI_YNQGwKGyO6YkAW05Cfl9oys869f7DMyJcthlj6c0CLwzMAGP8HM9AmxdCK92d4PL2Ujz-tI4CZsQOlzlecMLgElWBjl9FZtj-zWIWata2k/s1999/image1.png" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="523" data-original-width="1999" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjKNMvTyz1RhM0wqgn7eAGB9Lev3YUhKhHrcAmJt3SB1Gi6ozIaxHoPzAj-bm6II-_91viG2FXrfNZiiwSSI_YNQGwKGyO6YkAW05Cfl9oys869f7DMyJcthlj6c0CLwzMAGP8HM9AmxdCK92d4PL2Ujz-tI4CZsQOlzlecMLgElWBjl9FZtj-zWIWata2k/s16000/image1.png" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;Model performance increases with size, and the performance has not saturated even at the largest size of 5B params.&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;


&lt;br /&gt;


&lt;div style="line-height: 40%;"&gt;
    &lt;br /&gt;
&lt;/div&gt;
&lt;h2&gt;Conclusion&lt;/h2&gt;


&lt;p&gt;
We introduce the ScreenAI model along with a unified representation that enables us to develop self-supervised learning tasks leveraging data from both the UI and infographic domains. We also illustrate the impact of data generation using LLMs and investigate improving model performance on specific aspects by modifying the training mixture. We apply all of these techniques to build multi-task trained models that perform competitively with state-of-the-art approaches on a number of public benchmarks. However, we also note that our approach still lags behind large models and further research is needed to bridge this gap.
&lt;/p&gt;

&lt;div style="line-height: 40%;"&gt;
    &lt;br /&gt;
&lt;/div&gt;
&lt;h2&gt;Acknowledgements&lt;/h2&gt;


&lt;p&gt;
&lt;em&gt;This project is the result of joint work with Maria Wang, Fedir Zubach, Hassan Mansoor, Vincent Etter, Victor Carbune, Jason Lin, Jindong Chen and Abhanshu Sharma. We thank Fangyu Liu, Xi Chen, Efi Kokiopoulou, Jesse Berent, Gabriel Barcik, Lukas Zilka, Oriana Riva, Gang Li, Yang Li, Radu Soricut, and Tania Bedrax-Weiss for their insightful feedback and discussions, along with Rahul Aralikatte, Hao Cheng and Daniel Kim for their support in data preparation. We also thank Jay Yagnik, Blaise Aguera y Arcas, Ewa Dominowska, David Petrou, and Matt Sharifi for their leadership, vision and support. We are very grateful to Tom Small for helping us create the animation in this post.&lt;/em&gt;
&lt;/p&gt;</content><link href="http://blog.research.google/feeds/520087429457973735/comments/default" rel="replies" title="Post Comments" type="application/atom+xml"/><link href="http://blog.research.google/2024/03/screenai-visual-language-model-for-ui.html#comment-form" rel="replies" title="0 Comments" type="text/html"/><link href="http://www.blogger.com/feeds/8474926331452026626/posts/default/520087429457973735" rel="edit" type="application/atom+xml"/><link href="http://www.blogger.com/feeds/8474926331452026626/posts/default/520087429457973735" rel="self" type="application/atom+xml"/><link href="http://blog.research.google/2024/03/screenai-visual-language-model-for-ui.html" rel="alternate" title="ScreenAI: A visual language model for UI and visually-situated language understanding" type="text/html"/><author><name>Google AI</name><uri>http://www.blogger.com/profile/12098626514775266161</uri><email>noreply@blogger.com</email><gd:image height="16" rel="http://schemas.google.com/g/2005#thumbnail" src="https://img1.blogblog.com/img/b16-rounded.gif" width="16"/></author><media:thumbnail xmlns:media="http://search.yahoo.com/mrss/" height="72" url="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhoXlMR7pAKRRnyKZT8C40i6mPX0KKNGT6AFNvFOFIhZ7BD0rXaU3NS_aqISTGq9S_d0zozgcO0HR_v3R6Msm4uUDkaBFsFVx-miaDL6L0UhSz1Is8_L_iFjtvNE5OX9HX98t92b3r-rLQfJG1RrzVW354NdVUlIJVRLdQ_l4dFYa1773J-tJligdvh7QsX/s72-c/ScreenAI%20-%20hero.jpeg" width="72"/><thr:total>0</thr:total></entry><entry><id>tag:blogger.com,1999:blog-8474926331452026626.post-4328167517765145678</id><published>2024-03-19T08:00:00.000-07:00</published><updated>2024-03-19T08:00:00.150-07:00</updated><category scheme="http://www.blogger.com/atom/ns#" term="crowd-sourcing"/><category scheme="http://www.blogger.com/atom/ns#" term="datasets"/><category scheme="http://www.blogger.com/atom/ns#" term="Diversity"/><category scheme="http://www.blogger.com/atom/ns#" term="Health"/><title type="text">SCIN: A new resource for 
representative dermatology images</title><content type="html">&lt;span class="byline-author"&gt;Posted by Pooja Rao, Research Scientist, Google Research&lt;/span&gt;

&lt;img src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEi_fSTMFxLAMHLJ0rw7OAddGSPMW2tRl8kmTr2mWiiJunKxB8ZflMJeWkBmB5IqCD2LvRoikpN7OYnZO3CdKpArGn32b4o-T8ZD6XCPxmUBtE1-sPBi6J05y5_UrfbWSMTjNpldKYzM3xjXoC0iWU7q_a7Ktfi2S1hVHLY8uq1986yp_pgEjQn3elNuSUbJ/s1600/SCINHero.png" style="display: none;" /&gt;

&lt;p&gt;
Health datasets play a crucial role in research and medical education, but it can be challenging to create a dataset that represents the real world. For example, dermatology conditions are diverse in their appearance and severity and manifest differently across skin tones. Yet, existing dermatology image datasets often lack representation of everyday conditions (like rashes, allergies and infections) and skew towards lighter skin tones. Furthermore, race and ethnicity information is frequently missing, hindering our ability to assess disparities or create solutions.
&lt;/p&gt;
&lt;a name='more'&gt;&lt;/a&gt;


&lt;p&gt;
To address these limitations, we are releasing the &lt;a href="https://github.com/google-research-datasets/scin"&gt;Skin Condition Image Network (SCIN) dataset&lt;/a&gt; in collaboration with physicians at &lt;a href="https://med.stanford.edu/"&gt;Stanford Medicine&lt;/a&gt;. We designed SCIN to reflect the broad range of concerns that people search for online, supplementing the types of conditions typically found in clinical datasets. It contains images across various skin tones and body parts, helping to ensure that future AI tools work effectively for all. We've made &lt;a href="https://github.com/google-research-datasets/scin"&gt;the SCIN dataset&lt;/a&gt; freely available as an open-access resource for researchers, educators, and developers, and have taken careful steps to protect contributor privacy.   
&lt;/p&gt;



&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEi-lvUDxsY1bC8xXeRFKGtdyRiCk25knKK3tKzW2dCVtfvzFMUYvM7laqOBS0yP6Dnur5Fd945gbC96OMoiJ2nvguO6uguDArYkvnLUz5glvPlNpI1THL_bctcQCGlR670V4szxkHlcdvAJbP7T8HS7U3ASnHh_sWhSxoKJSsLN-1IPUpysj5ErdHaduz5r/s1327/image1.png" imageanchor="1" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="1118" data-original-width="1327" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEi-lvUDxsY1bC8xXeRFKGtdyRiCk25knKK3tKzW2dCVtfvzFMUYvM7laqOBS0yP6Dnur5Fd945gbC96OMoiJ2nvguO6uguDArYkvnLUz5glvPlNpI1THL_bctcQCGlR670V4szxkHlcdvAJbP7T8HS7U3ASnHh_sWhSxoKJSsLN-1IPUpysj5ErdHaduz5r/s16000/image1.png" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;Example set of images and metadata from the SCIN dataset.&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;



&lt;div style="line-height:40%;"&gt;
    &lt;br&gt;
&lt;/div&gt;
&lt;h2&gt;Dataset composition&lt;/h2&gt;


&lt;p&gt;
The SCIN dataset currently contains over 10,000 images of skin, nail, or hair conditions, directly contributed by individuals experiencing them. All contributions were made voluntarily with informed consent by individuals in the US, under a study approved by an institutional review board (IRB). To provide context for retrospective dermatologist labeling, contributors were asked to take images both close-up and from slightly further away. They were given the option to self-report demographic information and &lt;a href="https://en.wikipedia.org/wiki/Fitzpatrick_scale"&gt;tanning propensity&lt;/a&gt; (self-reported Fitzpatrick Skin Type, or sFST), and to describe the texture, duration and symptoms related to their concern.
&lt;/p&gt;
&lt;p&gt;
One to three dermatologists labeled each contribution with up to five dermatology conditions, along with a confidence score for each label. The SCIN dataset contains these individual labels, as well as an aggregated and weighted differential diagnosis derived from them that could be useful for model testing or training. These labels were assigned retrospectively and are not equivalent to a clinical diagnosis, but they allow us to compare the distribution of dermatology conditions in the SCIN dataset with existing datasets.
&lt;/p&gt;




&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEi7oYE7nKEvgBaW6SEHfGFzCrhnKqX5w86_7ujHMbpMENOByxcUTgAzXJrZCgv6kbDVmTN8NmKSBBSvF4XkWKcKf5DT_b3A5D50ZpAr-93i3a69KUFOZy54diZxH_wcf1PeKdFlRbEe_OZODxS0N4ZrHSaiki8ZslUfFUatw4w-0p0zzD4GRwlqgmPLR6gw/s1851/image2.png" imageanchor="1" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="775" data-original-width="1851" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEi7oYE7nKEvgBaW6SEHfGFzCrhnKqX5w86_7ujHMbpMENOByxcUTgAzXJrZCgv6kbDVmTN8NmKSBBSvF4XkWKcKf5DT_b3A5D50ZpAr-93i3a69KUFOZy54diZxH_wcf1PeKdFlRbEe_OZODxS0N4ZrHSaiki8ZslUfFUatw4w-0p0zzD4GRwlqgmPLR6gw/s16000/image2.png" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;The SCIN dataset contains largely allergic, inflammatory and infectious conditions while datasets from clinical sources focus on benign and malignant &lt;a href="https://en.wikipedia.org/wiki/Neoplasm"&gt;neoplasms&lt;/a&gt;.&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;




&lt;p&gt;
While many existing dermatology datasets focus on malignant and benign tumors and are intended to assist with skin cancer diagnosis, the SCIN dataset consists largely of common allergic, inflammatory, and infectious conditions. The majority of images in the SCIN dataset show early-stage concerns — more than half arose less than a week before the photo, and 30% arose less than a day before the image was taken. Conditions within this time window are seldom seen within the health system and therefore are underrepresented in existing dermatology datasets. 
&lt;/p&gt;
&lt;p&gt;
We also obtained dermatologist estimates of Fitzpatrick Skin Type (estimated FST or eFST) and layperson labeler estimates of &lt;a href="https://en.wikipedia.org/wiki/Monk_Skin_Tone_Scale"&gt;Monk Skin Tone&lt;/a&gt; (eMST) for the images. This allowed comparison of the skin condition and skin type distributions to those in existing dermatology datasets. Although we did not selectively target any skin types or skin tones, the SCIN dataset has a more balanced Fitzpatrick skin type distribution (with a greater share of Types 3, 4, 5, and 6) than similar datasets from clinical sources. 
&lt;/p&gt;




&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEiNnhVt5yEHKdMsi-tMYH9Q9oITruBrlrrfaQk8oopWHBr1qq6lfPrZnLrav-y2w7i9vgptlNDw_xKX3J8W0fZ1NfU-cOeINXc6bgf2vHJL3bc-UCWA7T846QQHkTvob6QbB3sR0HbwI9Vms3oXtAZ_zbrd4w_eAKLTo5-obYoG3A2urPmiF7RS5GcgVRhH/s1851/image3.png" imageanchor="1" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="620" data-original-width="1851" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEiNnhVt5yEHKdMsi-tMYH9Q9oITruBrlrrfaQk8oopWHBr1qq6lfPrZnLrav-y2w7i9vgptlNDw_xKX3J8W0fZ1NfU-cOeINXc6bgf2vHJL3bc-UCWA7T846QQHkTvob6QbB3sR0HbwI9Vms3oXtAZ_zbrd4w_eAKLTo5-obYoG3A2urPmiF7RS5GcgVRhH/s16000/image3.png" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;Self-reported and dermatologist-estimated Fitzpatrick Skin Type distribution in the SCIN dataset compared with existing un-enriched dermatology datasets &lt;a href="https://github.com/mattgroh/fitzpatrick17k"&gt;(Fitzpatrick17k&lt;/a&gt;, &lt;a href="https://www.fc.up.pt/addi/ph2%20database.html"&gt;PH²&lt;/a&gt;, &lt;a href="https://www.it.pt/AutomaticPage?id=3459"&gt;SKINL2&lt;/a&gt;, and&lt;a href="https://www.ncbi.nlm.nih.gov/pmc/articles/PMC7479321/"&gt; PAD-UFES-20&lt;/a&gt;).&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;



&lt;p&gt;
The &lt;a href="https://en.wikipedia.org/wiki/Fitzpatrick_scale"&gt;Fitzpatrick Skin Type&lt;/a&gt; scale was originally developed as a photo-typing scale to measure the response of skin types to UV radiation, and it is widely used in dermatology research. The Monk Skin Tone scale is a newer 10-shade scale that measures skin tone rather than skin phototype, capturing more nuanced differences between the darker skin tones. While neither scale was intended for retrospective estimation using images, the inclusion of these labels is intended to enable future research into skin type and tone representation in dermatology. For example, the SCIN dataset provides an initial benchmark for the distribution of these skin types and tones in the US population.
&lt;/p&gt;
&lt;p&gt;
The SCIN dataset has a high representation of women and younger individuals, likely reflecting a combination of factors. These could include differences in skin condition incidence, propensity to seek health information online, and variations in willingness to contribute to research across demographics.
&lt;/p&gt;



&lt;div style="line-height:40%;"&gt;
    &lt;br&gt;
&lt;/div&gt;
&lt;h2&gt;Crowdsourcing method&lt;/h2&gt;


&lt;p&gt;
To create the SCIN dataset, we used a novel crowdsourcing method, which we describe in the accompanying &lt;a href="https://arxiv.org/abs/2402.18545"&gt;research paper&lt;/a&gt; co-authored with investigators at &lt;a href="https://med.stanford.edu/"&gt;Stanford Medicine&lt;/a&gt;. This approach empowers individuals to play an active role in healthcare research. It allows us to reach people at earlier stages of their health concerns, potentially before they seek formal care. Crucially, this method uses advertisements on web search result pages — the starting point for many people’s health journey — to connect with participants. 
&lt;/p&gt;
&lt;p&gt;
Our results demonstrate that crowdsourcing can yield a high-quality dataset with a low spam rate. Over 97.5% of contributions were genuine images of skin conditions. After performing further filtering steps to exclude images that were out of scope for the SCIN dataset and to remove duplicates, we were able to release nearly 90% of the contributions received over the 8-month study period. Most images were sharp and well-exposed. Approximately half of the contributions include self-reported demographics, and 80% contain self-reported information relating to the skin condition, such as texture, duration, or other symptoms. We found that dermatologists’ ability to retrospectively assign a differential diagnosis depended more on the availability of self-reported information than on image quality.
&lt;/p&gt;




&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEj1QMRINpok_qmh5jjtktgqytapBRfHWDFxLKffzY9L_jG8uE8oJXA7QwtGY76gPksw5EH0yLuO7Ihk3IitXQDCjQ54DXlxFtpClbIIZzZAb6fDufHR-aW1m81cAMBqxmPIZsN8p3VYlys8b9cczZOzI-VB9d1Nwzk8nCnPTSCDwwh1fmEf4Q8DRdJHo6dR/s1999/image4.png" imageanchor="1" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="1610" data-original-width="1999" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEj1QMRINpok_qmh5jjtktgqytapBRfHWDFxLKffzY9L_jG8uE8oJXA7QwtGY76gPksw5EH0yLuO7Ihk3IitXQDCjQ54DXlxFtpClbIIZzZAb6fDufHR-aW1m81cAMBqxmPIZsN8p3VYlys8b9cczZOzI-VB9d1Nwzk8nCnPTSCDwwh1fmEf4Q8DRdJHo6dR/s16000/image4.png" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;Dermatologist confidence in their labels (scale from 1-5) depended on the availability of self-reported demographic and symptom information.&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;




&lt;p&gt;
While perfect image de-identification can never be guaranteed, protecting the privacy of individuals who contributed their images was a top priority when creating the SCIN dataset. Through informed consent, contributors were made aware of potential re-identification risks and advised to avoid uploading images with identifying features. Post-submission privacy protection measures included manual redaction or cropping to exclude potentially identifying areas, reverse image searches to exclude publicly available copies, and metadata removal or aggregation. The SCIN &lt;a href="https://github.com/google-research-datasets/scin?tab=License-1-ov-file#readme"&gt;Data Use License&lt;/a&gt; prohibits attempts to re-identify contributors.
&lt;/p&gt;
&lt;p&gt;
We hope the SCIN dataset will be a helpful resource for those working to advance inclusive dermatology research, education, and AI tool development. By demonstrating an alternative to traditional dataset creation methods, SCIN paves the way for more representative datasets in areas where self-reported data or retrospective labeling is feasible. 
&lt;/p&gt;



&lt;div style="line-height:40%;"&gt;
    &lt;br&gt;
&lt;/div&gt;
&lt;h2&gt;Acknowledgements&lt;/h2&gt;


&lt;p&gt;
&lt;em&gt;We are grateful to all our co-authors Abbi Ward, Jimmy Li, Julie Wang, Sriram Lakshminarasimhan, Ashley Carrick, Bilson Campana, Jay Hartford, Pradeep Kumar S, Tiya Tiyasirisokchai, Sunny Virmani, Renee Wong, Yossi Matias, Greg S. Corrado, Dale R. Webster, Dawn Siegel (Stanford Medicine), Steven Lin (Stanford Medicine), Justin Ko (Stanford Medicine), Alan Karthikesalingam and Christopher Semturs. We also thank Yetunde Ibitoye, Sami Lachgar, Lisa Lehmann, Javier Perez, Margaret Ann Smith (Stanford Medicine), Rachelle Sico, Amit Talreja, Annisah Um’rani and Wayne Westerlind for their essential contributions to this work. Finally, we are grateful to Heather Cole-Lewis, Naama Hammel, Ivor Horn, Michael Howell, Yun Liu, and Eric Teasley for their insightful comments on the study design and manuscript. &lt;/em&gt;
&lt;/p&gt;</content><link href="http://blog.research.google/feeds/4328167517765145678/comments/default" rel="replies" title="Post Comments" type="application/atom+xml"/><link href="http://blog.research.google/2024/03/scin-new-resource-for-representative.html#comment-form" rel="replies" title="0 Comments" type="text/html"/><link href="http://www.blogger.com/feeds/8474926331452026626/posts/default/4328167517765145678" rel="edit" type="application/atom+xml"/><link href="http://www.blogger.com/feeds/8474926331452026626/posts/default/4328167517765145678" rel="self" type="application/atom+xml"/><link href="http://blog.research.google/2024/03/scin-new-resource-for-representative.html" rel="alternate" title="SCIN: A new resource for representative dermatology images" type="text/html"/><author><name>Google AI</name><uri>http://www.blogger.com/profile/12098626514775266161</uri><email>noreply@blogger.com</email><gd:image height="16" rel="http://schemas.google.com/g/2005#thumbnail" src="https://img1.blogblog.com/img/b16-rounded.gif" width="16"/></author><media:thumbnail xmlns:media="http://search.yahoo.com/mrss/" height="72" url="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEi_fSTMFxLAMHLJ0rw7OAddGSPMW2tRl8kmTr2mWiiJunKxB8ZflMJeWkBmB5IqCD2LvRoikpN7OYnZO3CdKpArGn32b4o-T8ZD6XCPxmUBtE1-sPBi6J05y5_UrfbWSMTjNpldKYzM3xjXoC0iWU7q_a7Ktfi2S1hVHLY8uq1986yp_pgEjQn3elNuSUbJ/s72-c/SCINHero.png" width="72"/><thr:total>0</thr:total></entry><entry><id>tag:blogger.com,1999:blog-8474926331452026626.post-757976001746578714</id><published>2024-03-18T11:41:00.000-07:00</published><updated>2024-03-18T12:01:42.865-07:00</updated><category scheme="http://www.blogger.com/atom/ns#" term="Computer Vision"/><category scheme="http://www.blogger.com/atom/ns#" term="Machine Learning"/><title type="text">MELON: Reconstructing 3D objects from images with unknown poses</title><content type="html">&lt;span class="byline-author"&gt;Posted by Mark Matthews, Senior Software Engineer, and Dmitry Lagun, 
Research Scientist, Google Research&lt;/span&gt;


&lt;img src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEh8LjCbKjfNXVUyCpGiZysx_pNF5BK8p5VBCJXXPaz_Bb75CW-33weoMh0YaNcn4AdmGN-Pufd_XlsRzo2MWZLQxqgtri7Nip9tXoGX0CritvRKF-63StOWxp_gVaY-MTnOk9IvJdVt_CczVR6Ip_R8Yv32MHTw2-FckCTF4UOFrgMyq3PCPCkZaZ-nyMcE/s320/MELON%20HERO.jpg" style="display: none;" /&gt;

&lt;p&gt;
A person's prior experience and understanding of the world generally enables them to easily infer what an object looks like as a whole, even if only looking at a few 2D pictures of it. Yet the capacity for a computer to reconstruct the shape of an object in 3D given only a few images has remained a difficult algorithmic problem for years. This fundamental computer vision task has applications ranging from the creation of e-commerce 3D models to autonomous vehicle navigation. 
&lt;/p&gt;
&lt;a name='more'&gt;&lt;/a&gt;
&lt;p&gt;
A key part of the problem is how to determine the exact positions from which images were taken, known as &lt;em&gt;pose inference&lt;/em&gt;. If camera poses are known, a range of successful techniques — such as &lt;a href="https://www.matthewtancik.com/nerf"&gt;neural radiance fields&lt;/a&gt; (NeRF) or &lt;a href="https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/"&gt;3D Gaussian Splatting&lt;/a&gt; — can reconstruct an object in 3D. But if these poses are not available, then we face a difficult “chicken and egg” problem where we could determine the poses if we knew the 3D object, but we can’t reconstruct the 3D object until we know the camera poses. The problem is made harder by pseudo-symmetries — i.e., many objects look similar when viewed from different angles. For example, square objects like a chair tend to look similar with every 90° rotation. Pseudo-symmetries of an object can be revealed by rendering it on a turntable from various angles and plotting its photometric &lt;a href="https://en.wikipedia.org/wiki/Self-similarity"&gt;self-similarity&lt;/a&gt; map. 
&lt;/p&gt;

&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjt0nP5M8f5UodttSIPoY5t0JRXEuLosGgock3B0lyOzIn4icGF5jwVuxgX0PiRqc0kBbJ36CLiGA3KPrmaQbjKElGeHrsSRmkpDppU9abE84nuYu9MquqE3gULDzz_INDutmL2i1Wv3_tUpTh5U9UwSck9YRUeVyg-md2GByg3EQYYy7Vs_aeTEk5akpSo/s1764/image5.png" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="923" data-original-width="1764" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjt0nP5M8f5UodttSIPoY5t0JRXEuLosGgock3B0lyOzIn4icGF5jwVuxgX0PiRqc0kBbJ36CLiGA3KPrmaQbjKElGeHrsSRmkpDppU9abE84nuYu9MquqE3gULDzz_INDutmL2i1Wv3_tUpTh5U9UwSck9YRUeVyg-md2GByg3EQYYy7Vs_aeTEk5akpSo/s16000/image5.png" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;Self-Similarity map of a toy truck model. &lt;strong&gt;Left:&lt;/strong&gt; The model is rendered on a turntable from various &lt;a href="https://en.wikipedia.org/wiki/Azimuth"&gt;azimuthal angles&lt;/a&gt;, θ. &lt;strong&gt;Right:&lt;/strong&gt; The average &lt;a href="https://en.wikipedia.org/wiki/Norm_(mathematics)#Euclidean_norm"&gt;L2&lt;/a&gt; RGB similarity of a rendering from θ with that of θ*. The pseudo-similarities are indicated by the dashed red lines.&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;


&lt;p&gt;
The diagram above only visualizes one dimension of rotation. It becomes even more complex (and difficult to visualize) when introducing more degrees of freedom. Pseudo-symmetries make the problem &lt;em&gt;ill-posed&lt;/em&gt;, with naïve approaches often converging to local minima. In practice, such an approach might mistake the back view as the front view of an object, because they share a similar silhouette. Previous techniques (such as &lt;a href="https://chenhsuanlin.bitbucket.io/bundle-adjusting-NeRF/"&gt;BARF&lt;/a&gt; or &lt;a href="https://arxiv.org/abs/2205.15768"&gt;SAMURAI&lt;/a&gt;) side-step this problem by relying on an initial pose estimate that starts close to the global minimum. But how can we approach this if those aren’t available?
&lt;/p&gt;

&lt;p&gt;
Methods such as &lt;a href="https://openaccess.thecvf.com/content/ICCV2021/papers/Meng_GNeRF_GAN-Based_Neural_Radiance_Field_Without_Posed_Camera_ICCV_2021_paper.pdf"&gt;GNeRF&lt;/a&gt; and &lt;a href="https://dl.acm.org/doi/10.1145/3503161.3548078"&gt;VMRF&lt;/a&gt; leverage &lt;a href="https://en.wikipedia.org/wiki/Generative_adversarial_network"&gt;generative adversarial networks&lt;/a&gt; (GANs) to overcome the problem. These techniques have the ability to artificially “amplify” a limited number of training views, aiding reconstruction. GAN techniques, however, often have complex, sometimes unstable, training processes, making robust and reliable convergence difficult to achieve in practice. A range of other successful methods, such as &lt;a href="https://openaccess.thecvf.com/content/CVPR2023/html/Sinha_SparsePose_Sparse-View_Camera_Pose_Regression_and_Refinement_CVPR_2023_paper.html"&gt;SparsePose&lt;/a&gt; or &lt;a href="https://rust-paper.github.io/"&gt;RUST&lt;/a&gt;, can infer poses from a limited number of views, but require pre-training on a large dataset of posed images, which aren’t always available, and can suffer from “domain-gap” issues when inferring poses for different types of images.
&lt;/p&gt;

&lt;p&gt;
In “&lt;a href="https://arxiv.org/abs/2303.08096"&gt;MELON: NeRF with Unposed Images in SO(3)&lt;/a&gt;”, spotlighted at &lt;a href="https://3dvconf.github.io/2024/"&gt;3DV 2024&lt;/a&gt;, we present a technique that can determine object-centric camera poses entirely from scratch while reconstructing the object in 3D. &lt;a href="https://melon-nerf.github.io/"&gt;MELON&lt;/a&gt; (Modulo Equivalent Latent Optimization of NeRF) is one of the first techniques that can do this without initial pose camera estimates, complex training schemes or pre-training on labeled data. MELON is a relatively simple technique that can easily be integrated into existing NeRF methods. We demonstrate that MELON can reconstruct a NeRF from unposed images with state-of-the-art accuracy while requiring as few as 4–6 images of an object. 
&lt;/p&gt;


&lt;div style="line-height: 40%;"&gt;
    &lt;br /&gt;
&lt;/div&gt;
&lt;h2&gt;MELON&lt;/h2&gt;


&lt;p&gt;
We leverage two key techniques to aid convergence of this ill-posed problem. The first is a very lightweight, dynamically trained &lt;a href="https://en.wikipedia.org/wiki/Convolutional_neural_network"&gt;convolutional neural network&lt;/a&gt; (CNN) encoder that regresses camera poses from training images. We pass a downscaled training image to a four-layer CNN that infers the camera pose. This CNN is initialized from noise and requires no pre-training. Its capacity is so small that it forces similar-looking images to similar poses, providing an implicit regularization that greatly aids convergence.
&lt;/p&gt;

&lt;p&gt;
The second technique is a &lt;em&gt;modulo loss&lt;/em&gt; that simultaneously considers pseudo-symmetries of an object. We render the object from a fixed set of viewpoints for each training image, backpropagating the loss only through the view that best fits the training image. This effectively considers the plausibility of multiple views for each image. In practice, we find &lt;em&gt;N&lt;/em&gt;=2 views (viewing an object from the other side) is all that’s required in most cases, but we sometimes get better results with &lt;em&gt;N&lt;/em&gt;=4 for square objects.
&lt;/p&gt;

&lt;p&gt;
These two techniques are integrated into standard NeRF training, except that instead of fixed camera poses, poses are inferred by the CNN and duplicated by the modulo loss. Photometric gradients back-propagate through the best-fitting cameras into the CNN. We observe that cameras generally converge quickly to globally optimal poses (see animation below). After training of the neural field, MELON can synthesize novel views using standard NeRF rendering methods.
&lt;/p&gt;

&lt;p&gt;
We simplify the problem by using the &lt;a href="https://github.com/bmild/nerf"&gt;NeRF-Synthetic&lt;/a&gt; dataset, a popular benchmark for NeRF research and common in the pose-inference literature. This synthetic dataset has cameras at precisely fixed distances and a consistent “up” orientation, requiring us to infer only the two angular &lt;a href="https://en.wikipedia.org/wiki/Spherical_coordinate_system"&gt;spherical coordinates&lt;/a&gt; of the camera. This is the same as an object at the center of a globe with a camera always pointing at it, moving along the surface. We then only need the latitude and longitude (2 degrees of freedom) to specify the camera pose.
&lt;/p&gt;


&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhEjisRopoeGPgbCRa3sQ7hmBUtnfI6TRapBD7Yn96xeDA_LxzTayiw3DMijPHS0ovkLVTcQGpp2_gAyA_P5BCPwXuEcz7lApC8WQbGfMvj_aAxShjgsmcklf_-4ekgbFH6VZ92Ey3Ta4XAhZvEdc00D2o7SzPIOSnFAj8CgrdmdJunijsGaw1Zx46b94wk/s1315/image4.png" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="395" data-original-width="1315" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhEjisRopoeGPgbCRa3sQ7hmBUtnfI6TRapBD7Yn96xeDA_LxzTayiw3DMijPHS0ovkLVTcQGpp2_gAyA_P5BCPwXuEcz7lApC8WQbGfMvj_aAxShjgsmcklf_-4ekgbFH6VZ92Ey3Ta4XAhZvEdc00D2o7SzPIOSnFAj8CgrdmdJunijsGaw1Zx46b94wk/s16000/image4.png" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;MELON uses a dynamically trained lightweight CNN encoder that predicts a pose for each image. Predicted poses are replicated by the &lt;em&gt;modulo loss, &lt;/em&gt;which only penalizes the smallest L2 distance from the ground truth color. At evaluation time, the neural field can be used to generate novel views.&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;


&lt;br /&gt;


&lt;div style="line-height: 40%;"&gt;
    &lt;br /&gt;
&lt;/div&gt;
&lt;h2&gt;Results&lt;/h2&gt;


&lt;p&gt;
We compute two key metrics to evaluate MELON’s performance on the NeRF Synthetic dataset. The first, the pose error, quantifies the difference in orientation between the ground-truth and inferred poses as a single angular error, averaged across all training images. The second measures the accuracy of MELON’s rendered objects from novel views using the &lt;a href="https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio"&gt;peak signal-to-noise ratio&lt;/a&gt; (PSNR) against held-out test views. We see that MELON quickly converges to the approximate poses of most cameras within the first 1,000 steps of training, and achieves a competitive PSNR of 27.5 dB after 50k steps. 
&lt;/p&gt;


&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjU5wdw89PfwRbvZeaIWLM3rNEAo69__A-ovDwB5x8emIkAGZq05FgF-wDMNlkXPS6tOcC_0NJVD4Glq8eX02yb3CDIiqXbadI4lnvcZ_MI9sHUkz8risxP1orPA8ZnTZUq-PcRLPoEc_AmFuARCokXHQlTOv_q35TH1tivuK2PpA54hO7q7kh_M8ZynO-J/s960/image1.gif" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="480" data-original-width="960" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjU5wdw89PfwRbvZeaIWLM3rNEAo69__A-ovDwB5x8emIkAGZq05FgF-wDMNlkXPS6tOcC_0NJVD4Glq8eX02yb3CDIiqXbadI4lnvcZ_MI9sHUkz8risxP1orPA8ZnTZUq-PcRLPoEc_AmFuARCokXHQlTOv_q35TH1tivuK2PpA54hO7q7kh_M8ZynO-J/s16000/image1.gif" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;Convergence of MELON on a toy truck model during optimization. &lt;strong&gt;Left&lt;/strong&gt;: Rendering of the NeRF. &lt;strong&gt;Right&lt;/strong&gt;: Polar plot of predicted (blue &lt;em&gt;x&lt;/em&gt;), and ground truth (red dot) cameras.&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;


&lt;p&gt;
MELON achieves similar results for other scenes in the NeRF Synthetic dataset.
&lt;/p&gt;


&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhWEC7CE_iWu1QZ_jgEUHHCEdqaUBMO7cK-1DZuHaZRDq4Y59_CriUlb_aOSJP5psB6Cbs1E41mm81EsfwVM0zAUojRKToWwiDmPfaWFPr2UGqf6F4n3P8ZpgYxiqyWIgst6op3Fhsbu0nlR727zLVV38KqJvNFY_KDeoJbdOjJFpHjLZkEd95Z9TqSg4R_/s1999/image2.png" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="644" data-original-width="1999" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhWEC7CE_iWu1QZ_jgEUHHCEdqaUBMO7cK-1DZuHaZRDq4Y59_CriUlb_aOSJP5psB6Cbs1E41mm81EsfwVM0zAUojRKToWwiDmPfaWFPr2UGqf6F4n3P8ZpgYxiqyWIgst6op3Fhsbu0nlR727zLVV38KqJvNFY_KDeoJbdOjJFpHjLZkEd95Z9TqSg4R_/s16000/image2.png" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;Reconstruction quality comparison between ground-truth (GT) and MELON on NeRF-Synthetic scenes after 100k training steps.&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;

&lt;br /&gt;


&lt;div style="line-height: 40%;"&gt;
    &lt;br /&gt;
&lt;/div&gt;
&lt;h3&gt;Noisy images&lt;/h3&gt;


&lt;p&gt;
MELON also works well when performing &lt;a href="https://en.wikipedia.org/wiki/View_synthesis"&gt;novel view synthesis&lt;/a&gt; from extremely noisy, unposed images. We add varying amounts, &lt;em&gt;σ&lt;/em&gt;, of &lt;a href="https://en.wikipedia.org/wiki/Additive_white_Gaussian_noise"&gt;white Gaussian noise&lt;/a&gt; to the training images. For example, the object at &lt;em&gt;σ&lt;/em&gt;=1.0 below is impossible to make out, yet MELON can determine the pose and generate novel views of the object. 
&lt;/p&gt;



&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgHKYFcj-CKKc5kUvfsoOD5rBTp2QMnd3CdYiVzXjMClNwJrcgSrvIZngAdLgxUthE-aiXx5NapxcMx66i-Bi9RhC0zTRVkA0R8fj2A7lOnIdFDIE3YkTh_hWO2PhPa0FjYWYHuNUuae_tPhsrmVHJAkCeeI1f0ooJGe44KgpcO7jVNyLcnUvwtMX-KpJdD/s1182/image1.png" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="568" data-original-width="1182" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgHKYFcj-CKKc5kUvfsoOD5rBTp2QMnd3CdYiVzXjMClNwJrcgSrvIZngAdLgxUthE-aiXx5NapxcMx66i-Bi9RhC0zTRVkA0R8fj2A7lOnIdFDIE3YkTh_hWO2PhPa0FjYWYHuNUuae_tPhsrmVHJAkCeeI1f0ooJGe44KgpcO7jVNyLcnUvwtMX-KpJdD/s16000/image1.png" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;Novel view synthesis from noisy unposed 128×128 images. Top: Example of noise level present in training views. Bottom: Reconstructed model from noisy training views and mean angular pose error.&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;

&lt;p&gt;
This perhaps shouldn’t be too surprising, given that techniques like &lt;a href="https://bmild.github.io/rawnerf/"&gt;RawNeRF&lt;/a&gt; have demonstrated NeRF’s excellent de-noising capabilities with known camera poses. However, we did not expect MELON to work so robustly on noisy images with unknown camera poses. 
&lt;/p&gt;

&lt;div style="line-height: 40%;"&gt;
    &lt;br /&gt;
&lt;/div&gt;
&lt;h2&gt;Conclusion&lt;/h2&gt;


&lt;p&gt;
We present MELON, a technique that can determine object-centric camera poses to reconstruct objects in 3D without the need for approximate pose initializations, complex GAN training schemes or pre-training on labeled data. MELON is a relatively simple technique that can easily be integrated into existing NeRF methods. Though we have so far only demonstrated MELON on synthetic images, we are adapting our technique to work in real-world conditions. See the &lt;a href="https://arxiv.org/abs/2303.08096"&gt;paper&lt;/a&gt; and &lt;a href="https://melon-nerf.github.io/"&gt;MELON site&lt;/a&gt; to learn more.
&lt;/p&gt;


&lt;div style="line-height: 40%;"&gt;
    &lt;br /&gt;
&lt;/div&gt;
&lt;h2&gt;Acknowledgements&lt;/h2&gt;
&lt;p&gt;
&lt;em&gt;We would like to thank our paper co-authors Axel Levy, Matan Sela, and Gordon Wetzstein, as well as Florian Schroff and Hartwig Adam for continuous help in building this technology. We also thank Matthew Brown, Ricardo Martin-Brualla and Frederic Poitevin for their helpful feedback on the paper draft. We also acknowledge the use of the computational resources at the SLAC Shared Scientific Data Facility (SDF).&lt;/em&gt;
&lt;/p&gt;</content><link href="http://blog.research.google/feeds/757976001746578714/comments/default" rel="replies" title="Post Comments" type="application/atom+xml"/><link href="http://blog.research.google/2024/03/melon-reconstructing-3d-objects-from.html#comment-form" rel="replies" title="0 Comments" type="text/html"/><link href="http://www.blogger.com/feeds/8474926331452026626/posts/default/757976001746578714" rel="edit" type="application/atom+xml"/><link href="http://www.blogger.com/feeds/8474926331452026626/posts/default/757976001746578714" rel="self" type="application/atom+xml"/><link href="http://blog.research.google/2024/03/melon-reconstructing-3d-objects-from.html" rel="alternate" title="MELON: Reconstructing 3D objects from images with unknown poses" type="text/html"/><author><name>Google AI</name><uri>http://www.blogger.com/profile/12098626514775266161</uri><email>noreply@blogger.com</email><gd:image height="16" rel="http://schemas.google.com/g/2005#thumbnail" src="https://img1.blogblog.com/img/b16-rounded.gif" width="16"/></author><media:thumbnail xmlns:media="http://search.yahoo.com/mrss/" height="72" url="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEh8LjCbKjfNXVUyCpGiZysx_pNF5BK8p5VBCJXXPaz_Bb75CW-33weoMh0YaNcn4AdmGN-Pufd_XlsRzo2MWZLQxqgtri7Nip9tXoGX0CritvRKF-63StOWxp_gVaY-MTnOk9IvJdVt_CczVR6Ip_R8Yv32MHTw2-FckCTF4UOFrgMyq3PCPCkZaZ-nyMcE/s72-c/MELON%20HERO.jpg" width="72"/><thr:total>0</thr:total></entry><entry><id>tag:blogger.com,1999:blog-8474926331452026626.post-977556648557231190</id><published>2024-03-15T11:22:00.000-07:00</published><updated>2024-03-15T11:22:13.760-07:00</updated><category scheme="http://www.blogger.com/atom/ns#" term="Health"/><category scheme="http://www.blogger.com/atom/ns#" term="Responsible AI"/><title type="text">HEAL: A framework for health equity assessment of machine learning performance</title><content type="html">&lt;span class="byline-author"&gt;Posted by Mike Schaekermann, Research Scientist, Google 
Research, and Ivor Horn, Chief Health Equity Officer &amp;amp; Director, Google Core&lt;/span&gt;

&lt;img src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhYi3V0CsXup8WA6SSjPagoMWfkIpbr9oRWEaUM1vIWOX8_TsZs6ikqOn6qIGbqUzAPhOxwhEPNfWSkECIxRz5fJ629cRGScLraFn2CSw53Sr5_li8Fe7A9I1nMShys_15IiUZNhNiPh_ueFVcu_7f34A-A0pMXXVdDaSoSAf2h0jETJ1PemIR5I6o9pIIW/s1600/HEAL-Hero.png" style="display: none;" /&gt;

&lt;p&gt;
Health equity is a major societal concern worldwide, and health disparities have many causes. These sources include limitations in access to healthcare, differences in clinical treatment, and even fundamental differences in the diagnostic technology. In dermatology for example, skin cancer outcomes are worse for populations such as minorities, those with lower socioeconomic status, or individuals with limited healthcare access. While there is great promise in recent advances in machine learning (ML) and artificial intelligence (AI) to help improve healthcare, this transition from research to bedside must be accompanied by a careful understanding of whether and how these technologies impact health equity.
&lt;/p&gt;
&lt;a name='more'&gt;&lt;/a&gt; 

&lt;p&gt;
&lt;em&gt;Health equity&lt;/em&gt; is defined by public health organizations as fairness of opportunity for everyone to be as healthy as possible. Importantly, equity may be different from &lt;em&gt;equality&lt;/em&gt;. For example, people with greater barriers to improving their health may require more or different effort to experience this fair opportunity. Similarly, equity is not &lt;em&gt;fairness&lt;/em&gt; as defined in the AI for healthcare literature. Whereas AI fairness often strives for equal performance of the AI technology across different patient populations, this does not center the goal of prioritizing performance with respect to pre-existing health disparities.
&lt;/p&gt;

&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEi21VRS33NG-Imj1XlKXWtrwUrl4loEEywV0tO8M0JWtUFFksbTLOhilTZtMdJTgOBdXACUPQX-f5TMAFkABFhdv_cEDmFn4d-JirU78covJI32sHus6XQVJ1C1elwM_MExsQfeVCpFYlq9QZeynLNpLqmW8GqM-DKWiGSyi_18n8Xb3-8IeepHSyBZ6_2l/s1999/image2.jpg" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="1999" data-original-width="1609" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEi21VRS33NG-Imj1XlKXWtrwUrl4loEEywV0tO8M0JWtUFFksbTLOhilTZtMdJTgOBdXACUPQX-f5TMAFkABFhdv_cEDmFn4d-JirU78covJI32sHus6XQVJ1C1elwM_MExsQfeVCpFYlq9QZeynLNpLqmW8GqM-DKWiGSyi_18n8Xb3-8IeepHSyBZ6_2l/s16000/image2.jpg" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;Health equity considerations. An intervention (e.g., an ML-based tool, indicated in dark blue) promotes health equity if it helps reduce existing disparities in health outcomes (indicated in lighter blue).&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;

&lt;p&gt;
In “&lt;a href="https://www.thelancet.com/journals/eclinm/article/PIIS2589-5370(24)00058-0/fulltext"&gt;Health Equity Assessment of machine Learning performance (HEAL): a framework and dermatology AI model case study&lt;/a&gt;”, published in &lt;a href="https://www.thelancet.com/journals/eclinm/home"&gt;&lt;i&gt;The Lancet eClinicalMedicine&lt;/i&gt;&lt;/a&gt;, we propose a methodology to quantitatively assess whether ML-based health technologies perform equitably. In other words, does the ML model perform well for those with the worst health outcomes for the condition(s) the model is meant to address? This goal anchors on the principle that health equity should prioritize and measure model performance with respect to disparate health outcomes, which may be due to a number of factors that include structural inequities (e.g., demographic, social, cultural, political, economic, environmental and geographic).
&lt;/p&gt;
&lt;br /&gt; 

&lt;h2&gt;The health equity framework (HEAL)&lt;/h2&gt;

&lt;p&gt;
The HEAL framework proposes a 4-step process to estimate the likelihood that an ML-based health technology performs equitably:
&lt;/p&gt;
&lt;ol&gt;
&lt;li&gt;
Identify factors associated with health inequities and define tool performance metrics,
&lt;/li&gt;
&lt;li&gt;
Identify and quantify pre-existing health disparities,
&lt;/li&gt;
&lt;li&gt;
Measure the performance of the tool for each subpopulation,
&lt;/li&gt;
&lt;li&gt;
Measure the likelihood that the tool prioritizes performance with respect to health disparities.
&lt;/li&gt;
&lt;/ol&gt;

&lt;p&gt;
The final step’s output is termed the HEAL metric, which quantifies how anticorrelated the ML model’s performance is with health disparities. In other words, does the model perform better for populations that have the worst health outcomes?
&lt;/p&gt;
&lt;p&gt;
This 4-step process is designed to inform improvements for making ML model performance more equitable, and is meant to be iterative and re-evaluated on a regular basis. For example, the availability of health outcomes data in step (2) can inform the choice of demographic factors and brackets in step (1), and the framework can be applied again with new datasets, models and populations.
&lt;/p&gt;

&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjoGLCxn9QWS5QQpW39mJH1A_pw9wniWKIGGapN_gBC5WdxAWo4jHRS29GhNq7XBgNdZ867tMdP7TcszMz2WxUR4sYBFz0-dJ4cQZCODN2YFRjCP14QhNh_kMVGUdklbToOCYwHXV-UofhZdwZzDZudaVedOqvcC-QbW3LtMGb04FwFclbfzKHVUcqHodW_/s1999/image1.jpg" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="1352" data-original-width="1999" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjoGLCxn9QWS5QQpW39mJH1A_pw9wniWKIGGapN_gBC5WdxAWo4jHRS29GhNq7XBgNdZ867tMdP7TcszMz2WxUR4sYBFz0-dJ4cQZCODN2YFRjCP14QhNh_kMVGUdklbToOCYwHXV-UofhZdwZzDZudaVedOqvcC-QbW3LtMGb04FwFclbfzKHVUcqHodW_/s16000/image1.jpg" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;Framework for Health Equity Assessment of machine Learning performance (HEAL).&amp;nbsp;Our guiding principle is to avoid exacerbating health inequities, and these steps help us identify disparities and assess for inequitable model performance to move towards better outcomes for all.&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;

&lt;p&gt;
With this work, we take a step towards encouraging explicit assessment of the health equity considerations of AI technologies, and encourage prioritization of efforts during model development to reduce health inequities for subpopulations exposed to structural inequities that can precipitate disparate outcomes. We should note that the present framework does not model causal relationships and, therefore, cannot quantify the actual impact a new technology will have on reducing health outcome disparities. However, the HEAL metric may help identify opportunities for improvement, where the current performance is not prioritized with respect to pre-existing health disparities.
&lt;/p&gt;
&lt;br /&gt; 

&lt;h2&gt;Case study on a dermatology model&lt;/h2&gt;


&lt;p&gt;
As an illustrative case study, we applied the framework to a dermatology model, which utilizes a convolutional neural network similar to that described in &lt;a href="https://blog.research.google/2019/09/using-deep-learning-to-inform.html"&gt;prior work&lt;/a&gt;. This example dermatology model was trained to classify 288 skin conditions using a development dataset of 29k cases. The input to the model consists of three photos of a skin concern along with demographic information and a brief structured medical history. The output consists of a ranked list of possible matching skin conditions. 
&lt;/p&gt;
&lt;p&gt;
Using the HEAL framework, we evaluated this model by assessing whether it prioritized performance with respect to pre-existing health disparities. The model was designed to predict possible dermatologic conditions (from a list of hundreds) based on photos of a skin concern and patient metadata. Evaluation of the model is done using a top-3 agreement metric, which quantifies how often the top 3 output conditions match the most likely condition as suggested by a dermatologist panel. The HEAL metric is computed via the anticorrelation of this top-3 agreement with health outcome rankings. 
&lt;/p&gt;
&lt;p&gt;
We used a dataset of 5,420 teledermatology cases, enriched for diversity in age, sex and race/ethnicity, to retrospectively evaluate the model’s HEAL metric. The dataset consisted of “store-and-forward” cases from patients aged 20 years or older from primary care providers in the USA and skin cancer clinics in Australia. Based on a review of the literature, we decided to explore race/ethnicity, sex and age as potential factors of inequity, and used sampling techniques to ensure that our evaluation dataset had sufficient representation of all race/ethnicity, sex and age groups. To quantify pre-existing health outcomes for each subgroup we relied on measurements from &lt;a href="https://www.who.int/data/gho/data/themes/mortality-and-global-health-estimates/global-health-estimates-leading-causes-of-dalys"&gt;public&lt;/a&gt; &lt;a href="https://www.thelancet.com/journals/lancet/article/PIIS0140-6736(20)30925-9/fulltext"&gt;databases&lt;/a&gt; endorsed by the World Health Organization, such as &lt;a href="https://www.who.int/data/gho/indicator-metadata-registry/imr-details/4427"&gt;Years of Life Lost&lt;/a&gt; (YLLs) and &lt;a href="https://www.who.int/data/gho/indicator-metadata-registry/imr-details/158"&gt;Disability-Adjusted Life Years&lt;/a&gt; (DALYs; years of life lost plus years lived with disability).
&lt;/p&gt;

&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEiSS4J8AzS5iaHYvB7RyUVEDkx1ykrC7zOEAbUvjb8ZybZRZ0C71fRlJjPYBzGYVu9D3Ok0zRdz4MUdHMX6rOqnYKoHv91QNPw0TiqHJ6MKjtgn_UIqW-xoZeihO-A-ZrPgWT8bs-t9bSZWmMQ9AJaQh85BZWHH-T0KPWMx2unNO9HpTzYXiD_24gwNYWot/s1511/Table1.png" imageanchor="1" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="602" data-original-width="1511" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEiSS4J8AzS5iaHYvB7RyUVEDkx1ykrC7zOEAbUvjb8ZybZRZ0C71fRlJjPYBzGYVu9D3Ok0zRdz4MUdHMX6rOqnYKoHv91QNPw0TiqHJ6MKjtgn_UIqW-xoZeihO-A-ZrPgWT8bs-t9bSZWmMQ9AJaQh85BZWHH-T0KPWMx2unNO9HpTzYXiD_24gwNYWot/s16000/Table1.png" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;HEAL metric for all dermatologic conditions across race/ethnicity subpopulations, including health outcomes (YLLs per 100,000), model performance (top-3 agreement), and rankings for health outcomes and tool performance.&lt;br /&gt;(* Higher is better; measures the likelihood the model performs equitably with respect to the axes in this table.)&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;

&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEiMAQjyuGMXvzq4FxZg5Vhlgozwwnzza-QS-mjr3i0oOnDFIeqUGTrPxX2c7ssbpCZtLUoT2lpr8bXg_nJ3ToaaVe6Grge-HcWQl8SFy1gaBCoT-6ZHtFmQV4_S2sA6eOsdMFryegLjZFwOcPiqZDfFFItxqS96ysTZZn1OXVcbQSOG5WazZGjxSkNt9JQK/s1518/Table2.png" imageanchor="1" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="316" data-original-width="1518" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEiMAQjyuGMXvzq4FxZg5Vhlgozwwnzza-QS-mjr3i0oOnDFIeqUGTrPxX2c7ssbpCZtLUoT2lpr8bXg_nJ3ToaaVe6Grge-HcWQl8SFy1gaBCoT-6ZHtFmQV4_S2sA6eOsdMFryegLjZFwOcPiqZDfFFItxqS96ysTZZn1OXVcbQSOG5WazZGjxSkNt9JQK/s16000/Table2.png" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;HEAL metric for all dermatologic conditions across sexes, including health outcomes (DALYs per 100,000), model performance (top-3 agreement), and rankings for health outcomes and tool performance. (* As above.)&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table

&lt;p&gt;
Our analysis estimated that the model was 80.5% likely to perform equitably across race/ethnicity subgroups and 92.1% likely to perform equitably across sexes.
&lt;/p&gt;
&lt;p&gt;
However, while the model was likely to perform equitably across age groups for cancer conditions specifically, we discovered that it had room for improvement across age groups for non-cancer conditions. For example, those 70+ have the poorest health outcomes related to non-cancer skin conditions, yet the model didn't prioritize performance for this subgroup.
&lt;/p&gt;

&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEh4s5yfNQCksLIqP3kYuDXahUlOcJSCEtt-JkSTsecDft21uJ8JR0imnsPVGYHVQnc7OPo1WOkcwx2Yevu6su-rbqc1Fl6_NfzCKl0_vOvZA3PPnLkVWKFk7jHPJCm-x69MupVih_zct1YOXJVvSNUIsvn4rICk-_RWbOeuKj4HdRphBOakRXsiJ4lETJ_M/s1508/Table3.png" imageanchor="1" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="644" data-original-width="1508" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEh4s5yfNQCksLIqP3kYuDXahUlOcJSCEtt-JkSTsecDft21uJ8JR0imnsPVGYHVQnc7OPo1WOkcwx2Yevu6su-rbqc1Fl6_NfzCKl0_vOvZA3PPnLkVWKFk7jHPJCm-x69MupVih_zct1YOXJVvSNUIsvn4rICk-_RWbOeuKj4HdRphBOakRXsiJ4lETJ_M/s16000/Table3.png" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;HEAL metrics for all cancer and non-cancer dermatologic conditions across age groups, including health outcomes (DALYs per 100,000), model performance (top-3 agreement), and rankings for health outcomes and tool performance. (* As above.)&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;

&lt;br /&gt; 

&lt;h2&gt;Putting things in context&lt;/h2&gt;


&lt;p&gt;
For holistic evaluation, the HEAL metric cannot be employed in isolation. Instead this metric should be contextualized alongside many other factors ranging from computational efficiency and data privacy to ethical values, and aspects that may influence the results (e.g., selection bias or differences in representativeness of the evaluation data across demographic groups). 
&lt;/p&gt;
&lt;p&gt;
As an adversarial example, the HEAL metric can be artificially improved by deliberately reducing model performance for the most advantaged subpopulation until performance for that subpopulation is worse than all others. For illustrative purposes, given subpopulations A and B where A has worse health outcomes than B, consider the choice between two models: Model 1 (M1) performs 5 percentage points better for subpopulation A than for subpopulation B. Model 2 (M2) performs 5 percentage points worse on subpopulation A than B. The HEAL metric would be higher for M1 because it prioritizes performance on a subpopulation with worse outcomes. However, M1 may have absolute performances of just 75% and 70% for subpopulations A and B respectively, while M2 has absolute performances of 75% and 80% for subpopulations A and B respectively. Choosing M1 over M2 would therefore lead to worse overall performance: subpopulation B is worse off (70% vs. 80%) while no subpopulation is better off (A is at 75% under both models). 
&lt;/p&gt;
&lt;p&gt;
Accordingly, the HEAL metric should be used alongside a &lt;a href="https://en.wikipedia.org/wiki/Pareto_efficiency"&gt;Pareto condition&lt;/a&gt; (discussed further in the paper), which restricts model changes so that outcomes for each subpopulation are either unchanged or improved compared to the status quo, and performance does not worsen for any subpopulation.
&lt;/p&gt;
&lt;p&gt;
The HEAL framework, in its current form, assesses the likelihood that an ML-based model prioritizes performance for subpopulations with respect to their pre-existing health disparities. This differs from the goal of understanding whether ML will reduce disparities in outcomes across subpopulations in reality. Specifically, modeling improvements in outcomes requires a causal understanding of steps in the care journey that happen both before and after use of any given model. Future research is needed to address this gap.
&lt;/p&gt;
&lt;br /&gt; 

&lt;h2&gt;Conclusion&lt;/h2&gt;


&lt;p&gt;
The HEAL framework enables a quantitative assessment of the likelihood that health AI technologies prioritize performance with respect to health disparities. The case study demonstrates how to apply the framework in the dermatological domain, indicating a high likelihood that model performance is prioritized with respect to health disparities across sex and race/ethnicity, but also revealing the potential for improvements for non-cancer conditions across age. The case study also illustrates limitations in the ability to apply all recommended aspects of the framework (e.g., mapping societal context, availability of data), thus highlighting the complexity of health equity considerations of ML-based tools. 
&lt;/p&gt;
&lt;p&gt;
This work is a proposed approach to address a grand challenge for AI and health equity, and may provide a useful evaluation framework not only during model development, but during pre-implementation and real-world monitoring stages, e.g., in the form of health equity dashboards. We hold that the strength of the HEAL framework is in its future application to various AI tools and use cases and its refinement in the process. Finally, we acknowledge that a successful approach towards understanding the impact of AI technologies on health equity needs to be more than a set of metrics. It will require a set of goals agreed upon by a community that represents those who will be most impacted by a model.
&lt;/p&gt;
&lt;br /&gt; 

&lt;h2&gt;Acknowledgements&lt;/h2&gt;


&lt;p&gt;
&lt;em&gt;The research described here is joint work across many teams at Google. We are grateful to all our co-authors: Terry Spitz, Malcolm Pyles, Heather Cole-Lewis, Ellery Wulczyn, Stephen R. Pfohl, Donald Martin, Jr., Ronnachai Jaroensri, Geoff Keeling, Yuan Liu, Stephanie Farquhar, Qinghan Xue, Jenna Lester, Cían Hughes, Patricia Strachan, Fraser Tan, Peggy Bui, Craig H. Mermel, Lily H. Peng, Yossi Matias, Greg S. Corrado, Dale R. Webster, Sunny Virmani, Christopher Semturs, Yun Liu, and Po-Hsuan Cameron Chen. We also thank Lauren Winer, Sami Lachgar, Ting-An Lin, Aaron Loh, Morgan Du, Jenny Rizk, Renee Wong, Ashley Carrick, Preeti Singh, Annisah Um'rani, Jessica Schrouff, Alexander Brown, and Anna Iurchenko for their support of this project.&lt;/em&gt;
&lt;/p&gt;</content><link href="http://blog.research.google/feeds/977556648557231190/comments/default" rel="replies" title="Post Comments" type="application/atom+xml"/><link href="http://blog.research.google/2024/03/heal-framework-for-health-equity.html#comment-form" rel="replies" title="0 Comments" type="text/html"/><link href="http://www.blogger.com/feeds/8474926331452026626/posts/default/977556648557231190" rel="edit" type="application/atom+xml"/><link href="http://www.blogger.com/feeds/8474926331452026626/posts/default/977556648557231190" rel="self" type="application/atom+xml"/><link href="http://blog.research.google/2024/03/heal-framework-for-health-equity.html" rel="alternate" title="HEAL: A framework for health equity assessment of machine learning performance" type="text/html"/><author><name>Google AI</name><uri>http://www.blogger.com/profile/12098626514775266161</uri><email>noreply@blogger.com</email><gd:image height="16" rel="http://schemas.google.com/g/2005#thumbnail" src="https://img1.blogblog.com/img/b16-rounded.gif" width="16"/></author><media:thumbnail xmlns:media="http://search.yahoo.com/mrss/" height="72" url="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhYi3V0CsXup8WA6SSjPagoMWfkIpbr9oRWEaUM1vIWOX8_TsZs6ikqOn6qIGbqUzAPhOxwhEPNfWSkECIxRz5fJ629cRGScLraFn2CSw53Sr5_li8Fe7A9I1nMShys_15IiUZNhNiPh_ueFVcu_7f34A-A0pMXXVdDaSoSAf2h0jETJ1PemIR5I6o9pIIW/s72-c/HEAL-Hero.png" width="72"/><thr:total>0</thr:total></entry><entry><id>tag:blogger.com,1999:blog-8474926331452026626.post-7868032799856333119</id><published>2024-03-14T12:38:00.000-07:00</published><updated>2024-03-14T12:38:11.597-07:00</updated><category scheme="http://www.blogger.com/atom/ns#" term="Large Language Models"/><category scheme="http://www.blogger.com/atom/ns#" term="Machine Learning"/><category scheme="http://www.blogger.com/atom/ns#" term="Natural Language Processing"/><category scheme="http://www.blogger.com/atom/ns#" term="NeurIPS"/><title type="text">Cappy: Outperforming 
and boosting large multi-task language models with a small scorer</title><content type="html">&lt;span class="byline-author"&gt;Posted by Yun Zhu and Lijuan Liu, Software Engineers, Google Research&lt;/span&gt;

&lt;img src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEiFNlqVAnwoYdZ97LvC4-ipR6FeOc4o9udsTUtNBBWl5Y4XHclcrz3kTCibizteSBc_xsVLh-pyRiCCNfIzTDHEs7VsJcUMCk0EjUxzvKITKCncdx1y7u9JXGkXM6TyoZY5RhUt2l_up-Us0yIV-0-EUvHsjOlFNSSNgNHlpwK1PAliqcj4gSoLsYXhIi18/s320/Cappy%20hero.jpg" style="display: none;" /&gt;


&lt;p&gt;
Large language model (LLM) advancements have led to a new paradigm that unifies various natural language processing (NLP) tasks within an instruction-following framework. This paradigm is exemplified by recent multi-task LLMs, such as &lt;a href="https://arxiv.org/abs/2110.08207"&gt;T0&lt;/a&gt;, &lt;a href="https://arxiv.org/abs/2210.11416"&gt;FLAN&lt;/a&gt;, and &lt;a href="https://arxiv.org/abs/2212.12017"&gt;OPT-IML&lt;/a&gt;. First, multi-task data is gathered with each task following a task-specific template, where each labeled example is converted into an instruction (e.g., &lt;em&gt;“Put the concepts together to form a sentence: ski, mountain, skier”&lt;/em&gt;) paired with a corresponding response (e.g., &lt;em&gt;“Skier skis down the mountain”&lt;/em&gt;). These instruction-response pairs are used to train the LLM, resulting in a conditional generation model that takes an instruction as input and generates a response. Moreover, multi-task LLMs have exhibited remarkable task-wise generalization capabilities as they can address unseen tasks by understanding and solving brand-new instructions.
&lt;/p&gt;
&lt;a name='more'&gt;&lt;/a&gt;

&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhMcacnhPA68XiEskvhExF4SGFh4997UZzwvhYfXt-ReGXtzfGTamLB3LZoYSh8WWuf1dmlBnNAUecAMhrBTOMVF6vxsw3BqY8Ld5xPgSdZY_cywScxxxQ5e6uwhawA5VYDEj6VtSyOTNGZtjdLXieeFV5OLiDk3bnB-xaz4MIbvUO-7RPadk8iQDv3206V/s640/Cappy%20instruction-following.gif" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="177" data-original-width="640" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhMcacnhPA68XiEskvhExF4SGFh4997UZzwvhYfXt-ReGXtzfGTamLB3LZoYSh8WWuf1dmlBnNAUecAMhrBTOMVF6vxsw3BqY8Ld5xPgSdZY_cywScxxxQ5e6uwhawA5VYDEj6VtSyOTNGZtjdLXieeFV5OLiDk3bnB-xaz4MIbvUO-7RPadk8iQDv3206V/s16000/Cappy%20instruction-following.gif" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;The demonstration of the instruction-following pre-training of multi-task LLMs, e.g., FLAN. Pre-training tasks under this paradigm improves the performance for unseen tasks.&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;


&lt;p&gt;
Due to the complexity of understanding and solving various tasks solely using instructions, the size of multi-task LLMs typically spans from several billion parameters to hundreds of billions (e.g., &lt;a href="https://arxiv.org/abs/2210.11416"&gt;FLAN-11B&lt;/a&gt;, &lt;a href="https://arxiv.org/abs/2110.08207"&gt;T0-11B&lt;/a&gt; and &lt;a href="https://arxiv.org/abs/2212.12017"&gt;OPT-IML-175B&lt;/a&gt;). As a result, operating such sizable models poses significant challenges because they demand considerable computational power and impose substantial requirements on the memory capacities of GPUs and TPUs, making their training and inference expensive and inefficient. Extensive storage is also required to maintain a separate LLM copy for each downstream task. Moreover, the most powerful multi-task LLMs (e.g., FLAN-PaLM-540B) are closed-source, making them impossible to adapt. At the same time, in practical applications, harnessing a single multi-task LLM to manage all conceivable tasks in a zero-shot manner remains difficult, particularly when dealing with complex tasks, personalized tasks and those that cannot be succinctly defined using instructions. On the other hand, the size of downstream training data is usually insufficient to train a model well without incorporating rich prior knowledge. Hence, it has long been desirable to adapt LLMs with downstream supervision while bypassing storage, memory, and access issues. 
&lt;/p&gt;

&lt;p&gt;
Certain &lt;em&gt;parameter-efficient tuning&lt;/em&gt; strategies, including &lt;a href="https://aclanthology.org/2021.acl-long.353.pdf"&gt;prompt tuning&lt;/a&gt; and &lt;a href="https://openreview.net/pdf?id=nZeVKeeFYf9"&gt;adapters&lt;/a&gt;, substantially diminish storage requirements, but they still perform back-propagation through LLM parameters during the tuning process, thereby keeping their memory demands high. Additionally, some &lt;em&gt;&lt;a href="https://arxiv.org/pdf/2301.00234.pdf"&gt;in-context learning&lt;/a&gt;&lt;/em&gt; techniques circumvent parameter tuning by integrating a limited number of supervised examples into the instruction. However, these techniques are constrained by the model's maximum input length, which permits only a few samples to guide task resolution.
&lt;/p&gt;

&lt;p&gt;
In “&lt;a href="https://arxiv.org/abs/2311.06720"&gt;Cappy: Outperforming and Boosting Large Multi-Task LMs with a Small Scorer&lt;/a&gt;”, presented at &lt;a href="https://nips.cc/virtual/2023/index.html"&gt;NeurIPS 2023&lt;/a&gt;, we propose a novel approach that enhances the performance and efficiency of multi-task LLMs. We introduce a lightweight pre-trained scorer, Cappy, based on continual pre-training on top of &lt;a href="https://arxiv.org/abs/1907.11692"&gt;RoBERTa&lt;/a&gt; with merely 360 million parameters. Cappy takes in an instruction and a candidate response as input, and produces a score between 0 and 1, indicating an estimated correctness of the response with respect to the instruction. Cappy functions either independently on classification tasks or serves as an auxiliary component for LLMs, boosting their performance. Moreover, Cappy efficiently enables downstream supervision without requiring any finetuning, which avoids the need for back-propagation through LLM parameters and reduces memory requirements. Finally, adaptation with Cappy doesn’t require access to LLM parameters as it is compatible with closed-source multi-task LLMs, such as those only accessible via WebAPIs.
&lt;/p&gt;


&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgKxfEf2em6vxULs9wHGOB2jU7AiGMhUiEsJENdGWjB-8AMW6T2uRUrp3k3776491wzNsQCEk2T26AmiPNaKi-mfiIRNHe7JKZuR4ETQbHrM5h1knDNDBZ-qPw6sPGhtA4v0dz9YtKbHyoXPWEgYkY6r-tv8brepN8_Qq7MjCIwGUaYw5LmJMY4KLxu28ku/s1999/Cappy%20overview.png" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="975" data-original-width="1999" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgKxfEf2em6vxULs9wHGOB2jU7AiGMhUiEsJENdGWjB-8AMW6T2uRUrp3k3776491wzNsQCEk2T26AmiPNaKi-mfiIRNHe7JKZuR4ETQbHrM5h1knDNDBZ-qPw6sPGhtA4v0dz9YtKbHyoXPWEgYkY6r-tv8brepN8_Qq7MjCIwGUaYw5LmJMY4KLxu28ku/s16000/Cappy%20overview.png" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;Cappy takes an instruction and response pair as input and outputs a score ranging from 0 to 1, indicating an estimation of the correctness of the response with respect to the instruction.&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;

&lt;br /&gt;

&lt;div style="line-height: 40%;"&gt;
    &lt;br /&gt;
&lt;/div&gt;
&lt;h2&gt;Pre-training&lt;/h2&gt;


&lt;p&gt;
We begin with the same dataset collection, which includes 39 diverse datasets from &lt;a href="https://arxiv.org/abs/2202.01279"&gt;PromptSource&lt;/a&gt; that were used to train &lt;a href="https://arxiv.org/abs/2110.08207"&gt;T0&lt;/a&gt;. This collection encompasses a wide range of task types, such as question answering, sentiment analysis, and summarization. Each dataset is associated with one or more templates that convert each instance from the original datasets into an instruction paired with its ground truth response.
&lt;/p&gt;

&lt;p&gt;
Cappy's regression modeling requires each pre-training data instance to include an instruction-response pair along with a correctness annotation for the response, so we produce a dataset with correctness annotations that range from 0 to 1. For every instance within a generation task, we leverage an existing multi-task LLM to generate multiple responses by sampling, conditioned on the given instruction. Subsequently, we assign an annotation to the pair formed by the instruction and every response, using the similarity between the response and the ground truth response of the instance. Specifically, we employ &lt;a href="https://aclanthology.org/W04-1013/"&gt;Rouge-L&lt;/a&gt;, a commonly-used metric for measuring overall multi-task performance that has demonstrated a strong alignment with human evaluation, to calculate this similarity as a form of weak supervision.
&lt;/p&gt;

&lt;p&gt;
As a result, we obtain an effective regression dataset of 160 million instances paired with correctness score annotations. The final Cappy model is the result of continual pre-training using the regression dataset on top of the &lt;a href="https://arxiv.org/abs/1907.11692"&gt;RoBERTa&lt;/a&gt; model. The pre-training of Cappy is conducted on Google's &lt;a href="https://arxiv.org/abs/2304.01433"&gt;TPU-v4&lt;/a&gt;, with &lt;a href="https://arxiv.org/pdf/2310.16355.pdf"&gt;RedCoast&lt;/a&gt;, a lightweight toolkit for automating distributed training.
&lt;/p&gt;


&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEguKQnabejBzwCo7XEYZBJaaHi9_Z0Z03aofMxhmno2dKMbh2d6qVhmu7kKLN7FVExLXwYZYu1UEa1brRSC7bX3ASLyZymVyougwQqhCoE7Iio6DvIzdIK_dYT-1IGk41jZ6qdYcDynxezST6FY8u73opddwlGcGTf-3fXY4KfPo5hhfIinUl7iXRN7V6Sr/s1999/Cappy%20data%20augmentation.png" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="438" data-original-width="1999" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEguKQnabejBzwCo7XEYZBJaaHi9_Z0Z03aofMxhmno2dKMbh2d6qVhmu7kKLN7FVExLXwYZYu1UEa1brRSC7bX3ASLyZymVyougwQqhCoE7Iio6DvIzdIK_dYT-1IGk41jZ6qdYcDynxezST6FY8u73opddwlGcGTf-3fXY4KfPo5hhfIinUl7iXRN7V6Sr/s16000/Cappy%20data%20augmentation.png" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;Data augmentation with a multi-task LLM to construct a weakly supervised regression dataset for Cappy’s pre-training and fine-tuning.&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;

&lt;br /&gt;

&lt;div style="line-height: 40%;"&gt;
    &lt;br /&gt;
&lt;/div&gt;
&lt;h2&gt;Applying Cappy&lt;/h2&gt;


&lt;p&gt;
Cappy solves practical tasks within a candidate-selection mechanism. More specifically, given an instruction and a set of candidate responses, Cappy produces a score for each candidate response. This is achieved by inputting the instruction alongside each individual response, and then selecting the response with the highest score as its prediction. In classification tasks, all candidate responses are inherently predefined. For example, for an instruction of a sentiment classification task (e.g., “Based on this review, would the user recommend this product?: ‘Stunning even for the non-gamer.’”), the candidate responses are “Yes” or “No”. In such scenarios, Cappy functions independently. On the other hand, in generation tasks, candidate responses are not predefined, requiring an existing multi-task LLM to yield the candidate responses. In this case, Cappy serves as an auxiliary component of the multi-task LLM, enhancing its decoding.
&lt;/p&gt;

&lt;div style="line-height: 40%;"&gt;
    &lt;br /&gt;
&lt;/div&gt;
&lt;h3&gt;Adapting multi-task LLMs with Cappy &lt;/h3&gt;


&lt;p&gt;
When there is available downstream training data, Cappy enables effective and efficient adaptation of multi-task LLMs on downstream tasks. Specifically, we fine-tune Cappy to integrate downstream task information into LLM predictions. This process involves creating a separate regression dataset specific to the downstream training data with the same data annotation process used to construct the pre-training data. As a result, the fine-tuned Cappy collaborates with a multi-task LLM, boosting the LLM's performance on the downstream task.
&lt;/p&gt;

&lt;p&gt;
In contrast to other LLM tuning strategies, adapting LLMs with Cappy significantly reduces the high demand for device memory as it avoids the need for back-propagation through LLM parameters for downstream tasks. Moreover, Cappy adaptation does not rely on access to LLM parameters, making it compatible with closed-source multi-task LLMs, such as the ones only accessible via WebAPIs. Compared with in-context learning approaches, which circumvent model tuning by attaching training examples to the instruction prefix, Cappy is not restricted by the LLM's maximum input length. Thus, Cappy can incorporate an unlimited number of downstream training examples. Cappy can also be combined with other adaptation methods, such as fine-tuning and in-context learning, further boosting their overall performance.
&lt;/p&gt;


&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhf1zSOPmuCuPWHPOXYRczk86xESIKANYJN7jUqjkoSQabuQrDyLEfyLCXG0eAEHG1xiYL6jrZ8iMC14a2FhQs7XNwyncRdCyfIRa3KlLx3786yfSXfP9pEwtUEJ6ax7l5J8MchxjH9cV_hKqQFanTh3kNCs_JHYw0vsMOFi09-69-anFrqJShRgYFcKvfe/s1999/Cappy%20downstream%20adaptation.png" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="1040" data-original-width="1999" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhf1zSOPmuCuPWHPOXYRczk86xESIKANYJN7jUqjkoSQabuQrDyLEfyLCXG0eAEHG1xiYL6jrZ8iMC14a2FhQs7XNwyncRdCyfIRa3KlLx3786yfSXfP9pEwtUEJ6ax7l5J8MchxjH9cV_hKqQFanTh3kNCs_JHYw0vsMOFi09-69-anFrqJShRgYFcKvfe/s16000/Cappy%20downstream%20adaptation.png" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;Downstream adaptation comparison between Cappy and approaches that rely on an LLM’s parameters, such as fine-tuning and prompt tuning. Cappy’s application enhances multi-task LLMs.&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;
&lt;br /&gt;

&lt;div style="line-height: 40%;"&gt;
    &lt;br /&gt;
&lt;/div&gt;
&lt;h2&gt;Results&lt;/h2&gt;
&lt;p&gt;
We assess Cappy’s performance across eleven held-out language understanding classification tasks from &lt;a href="https://arxiv.org/abs/2202.01279"&gt;PromptSource&lt;/a&gt;. We demonstrate that Cappy, with 360M parameters, outperforms OPT-175B and OPT-IML-30B, and matches the accuracy of the best existing multi-task LLMs (T0-11B and OPT-IML-175B). These findings highlight Cappy’s capabilities and parameter efficiency, which can be credited to its scoring-based pre-training strategy that integrates contrastive information by differentiating between high-quality and low-quality responses. In contrast, previous multi-task LLMs depend exclusively on &lt;a href="https://en.wikipedia.org/wiki/Teacher_forcing"&gt;teacher-forcing training&lt;/a&gt; that utilizes only the ground truth responses.
&lt;/p&gt;




&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjyehdahD05Plit772klEfeGTN1GteCZcwwsyWbGgOTtgH4VD3hzPkF8PSDdYZe2EOE0nwL9xdNLZYLzvBJrm9ECTSGIWWUJ-Xo-1uVQUmN8uu0_5dLAERYPvOFfahf1ZZ2bId0tna1ch8BBXV9xKWpPKNIoAlihdNxZvlegShjI6Fjd5Twd8kv6w-axtUW/s1999/Cappy%20accuracy.png" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="1125" data-original-width="1999" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjyehdahD05Plit772klEfeGTN1GteCZcwwsyWbGgOTtgH4VD3hzPkF8PSDdYZe2EOE0nwL9xdNLZYLzvBJrm9ECTSGIWWUJ-Xo-1uVQUmN8uu0_5dLAERYPvOFfahf1ZZ2bId0tna1ch8BBXV9xKWpPKNIoAlihdNxZvlegShjI6Fjd5Twd8kv6w-axtUW/s16000/Cappy%20accuracy.png" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;The overall accuracy averaged over eleven test tasks from PromptSource. “RM” refers to a &lt;a href="https://huggingface.co/OpenAssistant/reward-model-deberta-v3-large-v2"&gt;pre-trained RLHF reward model&lt;/a&gt;. Cappy matches the best ones among existing multi-task LLMs.&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;


&lt;p&gt;
We also examine the adaptation of multi-task LLMs with Cappy on complex tasks from &lt;a href="https://arxiv.org/abs/2206.04615"&gt;BIG-Bench&lt;/a&gt;, a set of manually curated tasks that are considered beyond the capability of many LLMs. We focus on all 45 generation tasks in BIG-Bench, i.e., those that do not offer pre-established answer choices. We evaluate the performance using the Rouge-L score (representing the overall similarity between model generations and corresponding ground truths) on every test set, reporting the average score across 45 tests. In this experiment, all variants of FLAN-T5 serve as the backbone LLMs, and the foundational FLAN-T5 models are frozen. These results, shown below, suggest that Cappy enhances the performance of FLAN-T5 models by a large margin, consistently outperforming the most effective baseline achieved through sample selection using self-scoring of the LLM itself.
&lt;/p&gt;

&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhUmqWX5mq_zgs3nK6TSR3sNEunAburBnwxpIaFNxTuXbhLKeuI-c71IBxZw3tEnnnOHeE7heImqnZyluCAV92_2fhhXEfus_4R0MC78e_WOOXcSNvfyiVLNqNGhYK88YfiT__Ijss-OPpCo4XDz4vLFjtJKM-Mko_n2IgMabNI5J1a3LAVlIvBvRpiZ8GZ/s1999/Cappy%20averaged%20Rouge-L%20score.png" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="1625" data-original-width="1999" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhUmqWX5mq_zgs3nK6TSR3sNEunAburBnwxpIaFNxTuXbhLKeuI-c71IBxZw3tEnnnOHeE7heImqnZyluCAV92_2fhhXEfus_4R0MC78e_WOOXcSNvfyiVLNqNGhYK88YfiT__Ijss-OPpCo4XDz4vLFjtJKM-Mko_n2IgMabNI5J1a3LAVlIvBvRpiZ8GZ/s16000/Cappy%20averaged%20Rouge-L%20score.png" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;The averaged Rouge-L score over 45 complex tasks within BIG-Bench. The x-axis refers to FLAN-T5 models of different sizes. Every dashed line represents an approach working on FLAN-T5s. Self-scoring refers to using the cross-entropy of LLM to select responses. Cappy enhances the performance of FLAN-T5 models by a large margin.&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;

&lt;br&gt;

&lt;div style="line-height: 40%;"&gt;
    &lt;br /&gt;
&lt;/div&gt;
&lt;h2&gt;Conclusion&lt;/h2&gt;
&lt;p&gt;
We introduce Cappy, a novel approach that enhances the performance and efficiency of multi-task LLMs. In our experiments, we adapt a single LLM to several domains with Cappy. In the future, Cappy as a pre-trained model can potentially be used in other creative ways beyond single LLMs.
&lt;/p&gt;



&lt;div style="line-height: 40%;"&gt;
    &lt;br /&gt;
&lt;/div&gt;
&lt;h2&gt;Acknowledgments&lt;/h2&gt;
&lt;p&gt;
&lt;em&gt;Thanks to Bowen Tan, Jindong Chen, Lei Meng, Abhanshu Sharma and Ewa Dominowska for their valuable feedback. We would also like to thank Eric Xing and Zhiting Hu for their suggestions. &lt;/em&gt;
&lt;/p&gt;</content><link href="http://blog.research.google/feeds/7868032799856333119/comments/default" rel="replies" title="Post Comments" type="application/atom+xml"/><link href="http://blog.research.google/2024/03/cappy-outperforming-and-boosting-large.html#comment-form" rel="replies" title="0 Comments" type="text/html"/><link href="http://www.blogger.com/feeds/8474926331452026626/posts/default/7868032799856333119" rel="edit" type="application/atom+xml"/><link href="http://www.blogger.com/feeds/8474926331452026626/posts/default/7868032799856333119" rel="self" type="application/atom+xml"/><link href="http://blog.research.google/2024/03/cappy-outperforming-and-boosting-large.html" rel="alternate" title="Cappy: Outperforming and boosting large multi-task language models with a small scorer" type="text/html"/><author><name>Google AI</name><uri>http://www.blogger.com/profile/12098626514775266161</uri><email>noreply@blogger.com</email><gd:image height="16" rel="http://schemas.google.com/g/2005#thumbnail" src="https://img1.blogblog.com/img/b16-rounded.gif" width="16"/></author><media:thumbnail xmlns:media="http://search.yahoo.com/mrss/" height="72" url="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEiFNlqVAnwoYdZ97LvC4-ipR6FeOc4o9udsTUtNBBWl5Y4XHclcrz3kTCibizteSBc_xsVLh-pyRiCCNfIzTDHEs7VsJcUMCk0EjUxzvKITKCncdx1y7u9JXGkXM6TyoZY5RhUt2l_up-Us0yIV-0-EUvHsjOlFNSSNgNHlpwK1PAliqcj4gSoLsYXhIi18/s72-c/Cappy%20hero.jpg" width="72"/><thr:total>0</thr:total></entry><entry><id>tag:blogger.com,1999:blog-8474926331452026626.post-6090481872694489715</id><published>2024-03-12T14:15:00.000-07:00</published><updated>2024-03-19T09:12:05.568-07:00</updated><category scheme="http://www.blogger.com/atom/ns#" term="Generative AI"/><category scheme="http://www.blogger.com/atom/ns#" term="Graphs"/><category scheme="http://www.blogger.com/atom/ns#" term="Large Language Models"/><title type="text">Talk like a graph: Encoding graphs for large language models</title><content 
type="html">&lt;span class="byline-author"&gt;Posted by Bahare Fatemi and Bryan Perozzi, Research Scientists, Google Research&lt;/span&gt;

&lt;img src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEg8L7r_SCzFsKsWegtrn8_EOoO2imefs-V_GVHzbM0Xw7GmAxoXIIX0RtpJ2JvloeenxcKCNmhCH_VXRMpu8b5dJP39UkhMJS0wP86TUftZtUi-hfj6tZdVEn30MZAeQEx762q1vN-q4DWP2EdOBIHy_CgNFMcliaJYnzxZHjnuifbVWy52zlls20m4BkyJ/s1600/Screenshot%202024-03-12%20at%202.18.27%E2%80%AFPM.png" style="display: none;" /&gt;

&lt;p&gt;
Imagine all the things around you — your friends, tools in your kitchen, or even the parts of your bike. They are all connected in different ways. In computer science, the term &lt;em&gt;&lt;a href="https://en.wikipedia.org/wiki/Graph_(discrete_mathematics)"&gt;graph&lt;/a&gt; &lt;/em&gt;is used to describe connections between objects. Graphs consist of nodes (the objects themselves) and edges (connections between two nodes, indicating a relationship between them). Graphs are everywhere now. The internet itself is a giant graph of websites linked together. Even the knowledge search engines use is organized in a graph-like way.
&lt;/p&gt;&lt;a name='more'&gt;&lt;/a&gt;



&lt;p&gt;
Furthermore, consider the remarkable advancements in artificial intelligence — such as chatbots that can write stories in seconds, and even software that can interpret medical reports. This exciting progress is largely thanks to large language models (LLMs). New LLM technology is constantly being developed for different uses. 
&lt;/p&gt;
&lt;p&gt;
Since graphs are everywhere and LLM technology is on the rise, in “&lt;a href="https://openreview.net/forum?id=IuXR1CCrSi"&gt;Talk like a Graph: Encoding Graphs for Large Language Models&lt;/a&gt;”, presented at &lt;a href="https://iclr.cc/"&gt;ICLR 2024&lt;/a&gt;, we present a way to teach powerful LLMs how to better reason with graph information. Graphs are a useful way to organize information, but LLMs are mostly trained on regular text. The objective is to test different techniques to see what works best and gain practical insights. Translating graphs into text that LLMs can understand is a remarkably complex task. The difficulty stems from the inherent complexity of graph structures with multiple nodes and the intricate web of edges that connect them. Our work studies how to take a graph and translate it into a format that an LLM can understand. We also design a benchmark called &lt;em&gt;&lt;a href="https://github.com/google-research/google-research/tree/master/graphqa"&gt;GraphQA&lt;/a&gt;&lt;/em&gt; to study different approaches on different graph reasoning problems and show how to &lt;em&gt;phrase&lt;/em&gt; a graph-related problem in a way that enables the LLM to solve the graph problem. We show that LLM performance on graph reasoning tasks varies on three fundamental levels: 1) the graph encoding method, 2) the nature of the graph task itself, and 3) interestingly, the very structure of the graph considered. These findings give us clues on how to best represent graphs for LLMs. Picking the right method can make the LLM up to 60% better at graph tasks!
&lt;/p&gt;




&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjAnieWluvqGQtyh_L3a_Y7XfYUR2dBRGpQf58DpzJfIrkyM2JnwxiCOvTzDidvP-GtbtRe4NsJUEFlzpW8nQbf8WGQD6P_C2jjsRZeLiyDSO8QF8IiGCRYnSa4MxruywJt60gU8KrH6w87ZoBXsGbPmyWDx01j1nqSCaEtfFeNTmAWSLcVVcND8XuzoaHb/s1600/image7.gif" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="400" data-original-width="1600" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjAnieWluvqGQtyh_L3a_Y7XfYUR2dBRGpQf58DpzJfIrkyM2JnwxiCOvTzDidvP-GtbtRe4NsJUEFlzpW8nQbf8WGQD6P_C2jjsRZeLiyDSO8QF8IiGCRYnSa4MxruywJt60gU8KrH6w87ZoBXsGbPmyWDx01j1nqSCaEtfFeNTmAWSLcVVcND8XuzoaHb/s16000/image7.gif" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;Pictured, the process of encoding a graph as text using two different approaches and feeding the text and a question about the graph to the LLM.&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;


&lt;div style="line-height: 40%;"&gt;
    &lt;br /&gt;
&lt;/div&gt;
&lt;h2&gt;Graphs as text&lt;/h2&gt;


&lt;p&gt;
To systematically find the best way to translate a graph to text, we first design a benchmark called &lt;em&gt;&lt;a href="https://github.com/google-research/google-research/tree/master/graphqa"&gt;GraphQA&lt;/a&gt;&lt;/em&gt;. Think of GraphQA as an exam designed to evaluate powerful LLMs on graph-specific problems. We want to see how well LLMs can understand and solve problems that involve graphs in different setups. To create a comprehensive and realistic exam for LLMs, we don’t just use one type of graph; we use a mix of graphs ensuring breadth in the number of connections. This is mainly because different graph types make solving such problems easier or harder. This way, GraphQA can help expose biases in how an LLM thinks about the graphs, and the whole exam gets closer to a realistic setup that LLMs might encounter in the real world.
&lt;/p&gt;




&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhlHJKqbwgQJFMK4siQJH_Ggag9B8lStCQ4CcXk8iPnNPgxGPLYl_LTrIfjxuP7vKKtzJITlltZ5pcq7RElYNVQJ8PKi9Sr3ctigYfLs6SBlMAEhDHP2nV2PJ-uLhJxUkZ3MdAGV7R8rjw0u6Y8QTCwrMTyqz7tuxzb3TnIFabf4ZZbsSQ95MSboOA42i4w/s1368/image6.png" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="291" data-original-width="1368" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhlHJKqbwgQJFMK4siQJH_Ggag9B8lStCQ4CcXk8iPnNPgxGPLYl_LTrIfjxuP7vKKtzJITlltZ5pcq7RElYNVQJ8PKi9Sr3ctigYfLs6SBlMAEhDHP2nV2PJ-uLhJxUkZ3MdAGV7R8rjw0u6Y8QTCwrMTyqz7tuxzb3TnIFabf4ZZbsSQ95MSboOA42i4w/s16000/image6.png" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;Overview of our framework for reasoning with graphs using LLMs.&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;



&lt;p&gt;
GraphQA focuses on simple tasks related to graphs, like checking if an edge exists, calculating the number of nodes or edges, finding nodes that are connected to a specific node, and checking for cycles in a graph. These tasks might seem basic, but they require understanding the relationships between nodes and edges. By covering different types of challenges, from identifying patterns to creating new connections, GraphQA helps models learn how to analyze graphs effectively. These basic tasks are crucial for more complex reasoning on graphs, like finding the shortest path between nodes, detecting communities, or identifying influential nodes. Additionally, GraphQA includes random graphs generated using various algorithms like &lt;a href="https://en.wikipedia.org/wiki/Erd%C5%91s%E2%80%93R%C3%A9nyi_model"&gt;Erdős–Rényi&lt;/a&gt;, &lt;a href="https://en.wikipedia.org/wiki/Scale-free_network"&gt;scale-free networks&lt;/a&gt;, &lt;a href="https://en.wikipedia.org/wiki/Barab%C3%A1si%E2%80%93Albert_model"&gt;Barabási–Albert model&lt;/a&gt;, and &lt;a href="https://en.wikipedia.org/wiki/Stochastic_block_model"&gt;stochastic block model&lt;/a&gt;, as well as simpler graph structures like paths, complete graphs, and star graphs, providing a diverse set of data for training.
&lt;/p&gt;
&lt;p&gt;
When working with graphs, we also need to find ways to ask graph-related questions that LLMs can understand. &lt;em&gt;Prompting heuristics&lt;/em&gt; are different strategies for doing this. Let's break down the common ones:
&lt;/p&gt;
&lt;ul&gt;

&lt;li&gt;&lt;em&gt;Zero-shot&lt;/em&gt;: Simply describe the task ("Is there a cycle in this graph?") and tell the LLM to go for it. No examples provided.

&lt;/li&gt;&lt;li&gt;&lt;em&gt;Few-shot&lt;/em&gt;: This is like giving the LLM a mini practice test before the real deal. We provide a few example graph questions and their correct answers.

&lt;/li&gt;&lt;li&gt;&lt;em&gt;Chain-of-Thought (CoT)&lt;/em&gt;: Here, we show the LLM how to break down a problem step-by-step with examples. The goal is to teach it to generate its own "thought process" when faced with new graphs.

&lt;/li&gt;&lt;li&gt;&lt;em&gt;Zero-CoT&lt;/em&gt;: Similar to CoT, but instead of training examples, we give the LLM a simple prompt, like "Let's think step-by-step," to trigger its own problem-solving breakdown.

&lt;/li&gt;&lt;li&gt;&lt;em&gt;BAG (build a graph)&lt;/em&gt;: This is specifically for graph tasks. We add the phrase "Let's build a graph..." to the description, helping the LLM focus on the graph structure.
&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;
We explored different ways to translate graphs into text that LLMs can work with. Our key questions were:
&lt;/p&gt;
&lt;ul&gt;

&lt;li&gt;&lt;em&gt;Node encoding&lt;/em&gt;: How do we represent individual nodes? Options tested include simple &lt;a href="https://en.wikipedia.org/wiki/Integer"&gt;integers&lt;/a&gt;, common names (people, characters), and letters.

&lt;/li&gt;&lt;li&gt;&lt;em&gt;Edge encoding&lt;/em&gt;: How do we describe the relationships between nodes? Methods included parenthesis notation, phrases like "are friends", and symbolic representations like arrows.
&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;
Various node and edge encodings were combined systematically. This led to functions like the ones in the following figure:
&lt;/p&gt;

&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhHqSBznT1daPxVFMf0ZOR0uiZpjYrTG46t71FWy4tq5IMh-Ijhbzp_toJVmvp72FGrtoQXFkhCaaDVkhCzQXzcfRUPvW7151j22mmVxejpNJdO6VcvdHOkmEye_1zEBtfvAVgSw6RPFOiCpdo9LnetLvgrS-OL7IZPRLpBaCWGny_mzk6wpZcHDY-oS1ts/s855/image1.png" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="404" data-original-width="855" height="302" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhHqSBznT1daPxVFMf0ZOR0uiZpjYrTG46t71FWy4tq5IMh-Ijhbzp_toJVmvp72FGrtoQXFkhCaaDVkhCzQXzcfRUPvW7151j22mmVxejpNJdO6VcvdHOkmEye_1zEBtfvAVgSw6RPFOiCpdo9LnetLvgrS-OL7IZPRLpBaCWGny_mzk6wpZcHDY-oS1ts/w640-h302/image1.png" width="640" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;Examples of graph encoding functions used to encode graphs via text.&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;



&lt;div style="line-height: 40%;"&gt;
    &lt;br /&gt;
&lt;/div&gt;
&lt;h2&gt;Analysis and results&lt;/h2&gt;


&lt;p&gt;
We carried out three key experiments: one to test how LLMs handle graph tasks, and two to understand how the size of the LLM and different graph shapes affected performance. We ran all our experiments on GraphQA.
&lt;/p&gt;


&lt;div style="line-height: 40%;"&gt;
    &lt;br /&gt;
&lt;/div&gt;
&lt;h3&gt;How LLMs handle graph tasks &lt;/h3&gt;


&lt;p&gt;
In this experiment, we tested how well pre-trained LLMs tackle graph problems like identifying connections, cycles, and node degrees. Here is what we learned:
&lt;/p&gt;
&lt;ul&gt;

&lt;li&gt;&lt;em&gt;LLMs struggle:&lt;/em&gt; On most of these basic tasks, LLMs did not do much better than a random guess. 

&lt;/li&gt;&lt;li&gt;&lt;em&gt;Encoding matters significantly&lt;/em&gt;: How we represent the graph as text has a great effect on LLM performance. The "incident" encoding excelled for most of the tasks in general.
&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;
Our results are summarized in the following chart. 
&lt;/p&gt;



&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhYJLoJxI1twg6uV55JaKbVVnhO-dhgcaSN_B-FK9MTT8kKI1k_xnbGCvaEpmr82U4OGQxJ-oGNYOa0izo3jD1Ssvz8BVaKgw5ObjwN6_zS54BOALM_aO6TbLf-7SfcokAqRRC9fUbdErDeuadKBuRq7ihEootiLodZoYLKtZVDAgTI1ZrxviY7SI1PcFm5/s1864/image8.png" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="1152" data-original-width="1864" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhYJLoJxI1twg6uV55JaKbVVnhO-dhgcaSN_B-FK9MTT8kKI1k_xnbGCvaEpmr82U4OGQxJ-oGNYOa0izo3jD1Ssvz8BVaKgw5ObjwN6_zS54BOALM_aO6TbLf-7SfcokAqRRC9fUbdErDeuadKBuRq7ihEootiLodZoYLKtZVDAgTI1ZrxviY7SI1PcFm5/s16000/image8.png" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;Comparison of various graph encoder functions based on their accuracy on different graph tasks. The main conclusion from this figure is that the graph encoding functions matter significantly.&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;



&lt;div style="line-height: 40%;"&gt;
    &lt;br /&gt;
&lt;/div&gt;
&lt;h3&gt;Bigger is (usually) better &lt;/h3&gt;


&lt;p&gt;
In this experiment, we wanted to see if the size of the LLM (in terms of the number of parameters) affects how well they can handle graph problems. For that, we tested the same graph tasks on the XXS, XS, S, and L sizes of &lt;a href="https://ai.google/static/documents/palm2techreport.pdf"&gt;PaLM 2&lt;/a&gt;. Here is a summary of our findings:
&lt;/p&gt;
&lt;ul&gt;

&lt;li&gt;In general, bigger models did better on graph reasoning tasks. It seems like the extra parameters gave them space to learn more complex patterns.

&lt;/li&gt;&lt;li&gt;Oddly, size didn't matter as much for the “edge existence” task (finding out if two nodes in a graph are connected).

&lt;/li&gt;&lt;li&gt;Even the biggest LLM couldn't consistently beat a simple baseline solution on the cycle check problem (finding out if a graph contains a cycle or not). This shows LLMs still have room to improve with certain graph tasks.
&lt;/li&gt;
&lt;/ul&gt;





&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEiG-zu3s3K3iCIV5k2gakpMwQ_38a08_NrYeO3yITJc64EYiK36sksPulORuZR_BrGdmxZmCWEgIX2sWc42M4f3jpo8v17AddfoORPliE-SefptA4h4gye_g_PBKnufZ9kzTkI0f9MCKwSvuEqfcdgxNiycB2bGUQyUtXx8F7XU4qpXKZGEINZudJxlu-6L/s1227/image3.png" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="959" data-original-width="1227" height="500" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEiG-zu3s3K3iCIV5k2gakpMwQ_38a08_NrYeO3yITJc64EYiK36sksPulORuZR_BrGdmxZmCWEgIX2sWc42M4f3jpo8v17AddfoORPliE-SefptA4h4gye_g_PBKnufZ9kzTkI0f9MCKwSvuEqfcdgxNiycB2bGUQyUtXx8F7XU4qpXKZGEINZudJxlu-6L/w640-h500/image3.png" width="640" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;Effect of model capacity on graph reasoning task for PaLM 2-XXS, XS, S, and L.&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;



&lt;div style="line-height: 40%;"&gt;
    &lt;br /&gt;
&lt;/div&gt;
&lt;h3&gt;Do different graph shapes confuse LLMs?&lt;/h3&gt;


&lt;p&gt;
We wondered if the "shape" of a graph (how nodes are connected) influences how well LLMs can solve problems on it. Think of the following figure as different examples of graph shapes.
&lt;/p&gt;



&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgQ9tU8x8LvDYvwwN9XL4j64tXEq-7fGwnzYvS5zpNcEjk9yjxLH2yYmOAfKwr7_w9dHTUD1xtnI6IMAswp0pyManGDEO1ej1WeH9yByu-5ivtlfU5N-7OWJDtnR1uMeG7oWs1eqyiZFOyUpUa5GddPtECkd4ZvNPSx9rtS8fh83ahArgXtpKtVy7tQES9N/s1400/image4.png" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="195" data-original-width="1400" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgQ9tU8x8LvDYvwwN9XL4j64tXEq-7fGwnzYvS5zpNcEjk9yjxLH2yYmOAfKwr7_w9dHTUD1xtnI6IMAswp0pyManGDEO1ej1WeH9yByu-5ivtlfU5N-7OWJDtnR1uMeG7oWs1eqyiZFOyUpUa5GddPtECkd4ZvNPSx9rtS8fh83ahArgXtpKtVy7tQES9N/s16000/image4.png" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;Samples of graphs generated with different graph generators from GraphQA. ER, BA, SBM, and SFN refers to &lt;a href="https://en.wikipedia.org/wiki/Erd%C5%91s%E2%80%93R%C3%A9nyi_model"&gt;Erdős–Rényi&lt;/a&gt;, &lt;a href="https://en.wikipedia.org/wiki/Barab%C3%A1si%E2%80%93Albert_model"&gt;Barabási–Albert&lt;/a&gt;, &lt;a href="https://en.wikipedia.org/wiki/Stochastic_block_model"&gt;Stochastic Block Model&lt;/a&gt;, and &lt;a href="https://en.wikipedia.org/wiki/Scale-free_network"&gt;Scale-Free Network&lt;/a&gt; respectively.&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;




&lt;p&gt;
We found that graph structure has a big impact on LLM performance. For example, in a task asking if a cycle exists, LLMs did great on tightly interconnected graphs (cycles are common there) but struggled on path graphs (where cycles never happen). Interestingly, providing some mixed examples helped the LLM adapt. For instance, for cycle check, we added some examples containing a cycle and some examples with no cycles as few-shot examples in our prompt. Similar patterns occurred with other tasks.
&lt;/p&gt;




&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgqf5piX3TuSfL0DqpUG7ZkBgMKEqqeCIh0feFG4ddMiaHTFgLY3iPkI4UD3gZpAKeTHgfhItKeXo8P3M4sGSQRZJJsXMAVFutTDuWziSwt1CBvt7kV1VSOSHqGTu0yk7lAym4XYJERrS3FETWbj17agumgHaln1EevI_LyzqAbNFZjYNPZGKjw1fgKBydk/s1864/image5.png" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="1152" data-original-width="1864" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgqf5piX3TuSfL0DqpUG7ZkBgMKEqqeCIh0feFG4ddMiaHTFgLY3iPkI4UD3gZpAKeTHgfhItKeXo8P3M4sGSQRZJJsXMAVFutTDuWziSwt1CBvt7kV1VSOSHqGTu0yk7lAym4XYJERrS3FETWbj17agumgHaln1EevI_LyzqAbNFZjYNPZGKjw1fgKBydk/s16000/image5.png" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;Comparing different graph generators on different graph tasks. The main observation here is that graph structure has a significant impact on the LLM’s performance. ER, BA, SBM, and SFN refers to &lt;a href="https://en.wikipedia.org/wiki/Erd%C5%91s%E2%80%93R%C3%A9nyi_model"&gt;Erdős–Rényi&lt;/a&gt;, &lt;a href="https://en.wikipedia.org/wiki/Barab%C3%A1si%E2%80%93Albert_model"&gt;Barabási–Albert&lt;/a&gt;, &lt;a href="https://en.wikipedia.org/wiki/Stochastic_block_model"&gt;Stochastic Block Model&lt;/a&gt;, and &lt;a href="https://en.wikipedia.org/wiki/Scale-free_network"&gt;Scale-Free Network&lt;/a&gt; respectively.&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;


&lt;div style="line-height: 40%;"&gt;
    &lt;br /&gt;
&lt;/div&gt;
&lt;h2&gt;Conclusion&lt;/h2&gt;


&lt;p&gt;
In short, we dug deep into how to best represent graphs as text so LLMs can understand them. We found three major factors that make a difference:
&lt;/p&gt;
&lt;ul&gt;

&lt;li&gt;&lt;em&gt;How to translate the graph to text&lt;/em&gt;: how we represent the graph as text significantly influences LLM performance. The incident encoding excelled for most of the tasks in general.

&lt;/li&gt;&lt;li&gt;&lt;em&gt;Task type&lt;/em&gt;: Certain types of graph questions tend to be harder for LLMs, even with a good translation from graph to text.

&lt;/li&gt;&lt;li&gt;&lt;em&gt;Graph structure&lt;/em&gt;: Surprisingly, the "shape" of the graph on which we do inference (dense with connections, sparse, etc.) influences how well an LLM does.
&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;
This study revealed key insights about how to prepare graphs for LLMs. The right encoding techniques can significantly boost an LLM's accuracy on graph problems (ranging from around 5% to over 60% improvement). Our new benchmark, GraphQA, will help drive further research in this area.
&lt;/p&gt;



&lt;div style="line-height: 40%;"&gt;
    &lt;br /&gt;
&lt;/div&gt;
&lt;h2&gt;Acknowledgements&lt;/h2&gt;


&lt;p&gt;
&lt;em&gt;We would like to express our gratitude to our co-author, Jonathan Halcrow, for his valuable contributions to this work. We express our sincere gratitude to Anton Tsitsulin, Dustin Zelle, Silvio Lattanzi, Vahab Mirrokni, and the entire graph mining team at Google Research, for their insightful comments, thorough proofreading, and constructive feedback which greatly enhanced the quality of our work. We would also like to extend special thanks to Tom Small for creating the animation used in this post.&lt;/em&gt;
&lt;/p&gt;</content><link href="http://blog.research.google/feeds/6090481872694489715/comments/default" rel="replies" title="Post Comments" type="application/atom+xml"/><link href="http://blog.research.google/2024/03/talk-like-graph-encoding-graphs-for.html#comment-form" rel="replies" title="0 Comments" type="text/html"/><link href="http://www.blogger.com/feeds/8474926331452026626/posts/default/6090481872694489715" rel="edit" type="application/atom+xml"/><link href="http://www.blogger.com/feeds/8474926331452026626/posts/default/6090481872694489715" rel="self" type="application/atom+xml"/><link href="http://blog.research.google/2024/03/talk-like-graph-encoding-graphs-for.html" rel="alternate" title="Talk like a graph: Encoding graphs for large language models" type="text/html"/><author><name>Google AI</name><uri>http://www.blogger.com/profile/12098626514775266161</uri><email>noreply@blogger.com</email><gd:image height="16" rel="http://schemas.google.com/g/2005#thumbnail" src="https://img1.blogblog.com/img/b16-rounded.gif" width="16"/></author><media:thumbnail xmlns:media="http://search.yahoo.com/mrss/" height="72" url="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEg8L7r_SCzFsKsWegtrn8_EOoO2imefs-V_GVHzbM0Xw7GmAxoXIIX0RtpJ2JvloeenxcKCNmhCH_VXRMpu8b5dJP39UkhMJS0wP86TUftZtUi-hfj6tZdVEn30MZAeQEx762q1vN-q4DWP2EdOBIHy_CgNFMcliaJYnzxZHjnuifbVWy52zlls20m4BkyJ/s72-c/Screenshot%202024-03-12%20at%202.18.27%E2%80%AFPM.png" width="72"/><thr:total>0</thr:total></entry><entry><id>tag:blogger.com,1999:blog-8474926331452026626.post-470840348983280912</id><published>2024-03-11T12:08:00.000-07:00</published><updated>2024-03-11T12:13:03.824-07:00</updated><category scheme="http://www.blogger.com/atom/ns#" term="Machine Intelligence"/><category scheme="http://www.blogger.com/atom/ns#" term="Natural Language Processing"/><title type="text">Chain-of-table: Evolving tables in the reasoning chain for table understanding</title><content type="html">&lt;span 
class="byline-author"&gt;Posted by Zilong Wang, Student Researcher, and Chen-Yu Lee, Research Scientist, Cloud AI Team&lt;/span&gt;

&lt;img src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEg1smBN07qkS32Aop4if0AeINQQea0Grv8dw7GiRFBNoHBlgkkftynVBNjO6BckpF4vq8d0VqC1v0LoeFAVFqOLrBGlqvMNiCMUtIhHxVvsBjbPxvZLcNcD_Sa1sI_bDlqDLWn_C39MbPNm8VUjr2vhTBuaL4qCc1LUB1VH5iM0UVsswIWWq_uQg88YRWmb/s832/Chain-of-Table.png" style="display: none;" /&gt;

&lt;p&gt;
People use tables every day to organize and interpret complex information in a structured, easily accessible format. Due to the ubiquity of such tables, reasoning over tabular data has long been a central topic in &lt;a href="https://en.wikipedia.org/wiki/Natural_language_processing"&gt;natural language processing&lt;/a&gt; (NLP). Researchers in this field have aimed to leverage language models to help users answer questions, verify statements, and analyze data based on tables. However, language models are trained over large amounts of plain text, so the inherently structured nature of tabular data can be difficult for language models to fully comprehend and utilize.
&lt;/p&gt;
&lt;a name='more'&gt;&lt;/a&gt;
&lt;p&gt;
Recently, &lt;a href="https://en.wikipedia.org/wiki/Large_language_model"&gt;large language models&lt;/a&gt; (LLMs) have achieved outstanding performance across diverse &lt;a href="https://en.wikipedia.org/wiki/Natural-language_understanding"&gt;natural language understanding&lt;/a&gt; (NLU) tasks by generating reliable reasoning chains, as shown in works like &lt;a href="https://arxiv.org/abs/2201.11903"&gt;Chain-of-Thought&lt;/a&gt; and &lt;a href="https://arxiv.org/abs/2205.10625"&gt;Least-to-Most&lt;/a&gt;. However, the most suitable way for LLMs to reason over tabular data remains an open question.
&lt;/p&gt;
&lt;p&gt;
In “&lt;a href="https://arxiv.org/abs/2401.04398"&gt;Chain-of-Table: Evolving Tables in the Reasoning Chain for Table Understanding&lt;/a&gt;”, we propose a framework to tackle table understanding tasks, where we guide LLMs via in-context learning to outline their reasoning step by step, updating a given table iteratively to reflect each part of a thought process, akin to how people solve table-based problems. This enables the LLM to transform the table into simpler and more manageable segments so that it can understand and analyze each part of the table in depth. This approach has yielded significant improvements and achieved new state-of-the-art results on the &lt;a href="https://arxiv.org/abs/1508.00305"&gt;WikiTQ&lt;/a&gt;, &lt;a href="https://arxiv.org/abs/1909.02164"&gt;TabFact&lt;/a&gt;, and &lt;a href="https://arxiv.org/abs/2104.00369"&gt;FeTaQA&lt;/a&gt; benchmarks. The figure below shows a high-level overview of the proposed Chain-of-Table and other methods.
&lt;/p&gt;


&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjKT_df1rC8nK-ULOLPjtJ8gFaDHzRi7DX92Ix7OboQhOUNvqh_Melp9SVRWEsgL1Vu6IX9RuMgX7_UIuyeuHr7H0YwJdo6om2M2rX5d9wqOWsXWVAa9o0S75bIt7qG2DiGlhYypk0KKBMSxz2Z8vgmQqxTvy3bVrmH4nSC4Nzv8fZm6mOoA5yEXN_CgC4h/s1478/image2.png" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="964" data-original-width="1478" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjKT_df1rC8nK-ULOLPjtJ8gFaDHzRi7DX92Ix7OboQhOUNvqh_Melp9SVRWEsgL1Vu6IX9RuMgX7_UIuyeuHr7H0YwJdo6om2M2rX5d9wqOWsXWVAa9o0S75bIt7qG2DiGlhYypk0KKBMSxz2Z8vgmQqxTvy3bVrmH4nSC4Nzv8fZm6mOoA5yEXN_CgC4h/s16000/image2.png" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;Given a complex table where a cyclist’s nationality and name are in the same cell, (a) generic, multi-step reasoning is unable to provide the correct answer; (b) program-aided reasoning generates and executes programs (e.g., SQL queries) to deliver the answer, but falls short of accurately addressing the question. In contrast, (c) Chain-of-Table iteratively samples a chain of operations that effectively transform the complex table into a version specifically tailored to the question.&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;
&lt;br /&gt;

&lt;h2&gt;Chain-of-Table&lt;/h2&gt;


&lt;p&gt;
In Chain-of-Table, we guide LLMs using &lt;a href="https://arxiv.org/abs/2005.14165"&gt;in-context learning&lt;/a&gt; to iteratively generate operations and to update the table to represent its reasoning chain over tabular data. This enables LLMs to dynamically plan the next operation based on the results of previous ones. This continuous evolution of the table forms a chain, which provides a more structured and clear representation of the reasoning process for a given problem and enables more accurate and reliable predictions from the LLM. 
&lt;/p&gt;
&lt;p&gt;
For example, when asked, “Which actor has the most NAACP image awards?” the Chain-of-Table framework prompts an LLM to generate tabular operations mirroring tabular reasoning processes. It first identifies the relevant columns. Then, it aggregates rows based on shared content. Finally, it reorders the aggregated results to yield a final table that clearly answers the posed question. 
&lt;/p&gt;
&lt;p&gt;
These operations transform the table to align with the question presented. To balance performance with computational expense on large tables, we construct the operation chain according to a subset of tabular rows. Meanwhile, the step-by-step operations reveal the underlying reasoning process through the display of intermediate results from the tabular operations, fostering enhanced interpretability and understanding.
&lt;/p&gt;


&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEj8JwNNHW6SR1PSRTj79oQKqE1K48onxbcM9uwIlacEGnUqtua0jgkXQ-CfyUukJ0qiBhqsKl1_YfeJmcqkMEe5TR08eo9ZEqymWYszwNyKfZjcx0T-wYwEnHqCvdlf9lJAG8UTBN6RZQngH7sv0hQ9szR1wgjyiFSaOIqVHC08bJv6HeaXvWJMHH41wI4_/s1999/image4.png" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="983" data-original-width="1999" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEj8JwNNHW6SR1PSRTj79oQKqE1K48onxbcM9uwIlacEGnUqtua0jgkXQ-CfyUukJ0qiBhqsKl1_YfeJmcqkMEe5TR08eo9ZEqymWYszwNyKfZjcx0T-wYwEnHqCvdlf9lJAG8UTBN6RZQngH7sv0hQ9szR1wgjyiFSaOIqVHC08bJv6HeaXvWJMHH41wI4_/s16000/image4.png" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;Illustration of the tabular reasoning process in Chain-of-Table. This iterative process involves dynamically planning an operation chain and accurately storing intermediate results in the transformed tables. These intermediate tables serve as a tabular thought process that can guide the LLM to arrive at the correct answer more reliably.&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;
&lt;br /&gt;

&lt;p&gt;
Chain-of-Table consists of three main stages. In the first stage, it instructs the LLM to dynamically plan the next operation by in-context learning. Specifically, the prompt involves three components as shown in the following figure: 
&lt;/p&gt;
&lt;ol&gt;

&lt;li&gt;  The question &lt;em&gt;Q&lt;/em&gt;: “Which country had the most cyclists finish in the top 3?”

&lt;/li&gt;&lt;li&gt;  The operation history &lt;em&gt;chain&lt;/em&gt;: &lt;code&gt;f_add_col(Country)&lt;/code&gt; and &lt;code&gt;f_select_row(1, 2, 3)&lt;/code&gt;.

&lt;/li&gt;&lt;li&gt;  The latest intermediate table &lt;em&gt;T&lt;/em&gt;: the transformed intermediate table. 
&lt;/li&gt;
&lt;/ol&gt;
&lt;p&gt;
By providing the triplet &lt;em&gt;(T, Q, chain)&lt;/em&gt; in the prompt, the LLM can observe the previous tabular reasoning process and select the next operation from the operation pool to complete the reasoning chain step by step.
&lt;/p&gt;


&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjIBKUfxjF1_KB5gtaj8DRWoqLWQKe_DJXLV6-1sClG1oKutdKujDHyzYgvGlAhQDK235cBoKwNkj7cuA4kLzCt_sltdiyuZSMmEKdEoDS7_XkOFTujyekDI8gJfSLRZkT5yIdGPCVvEVQPoueDgK7dXgyAs04fK3AuwSMurECyNc3ywvzDLAyoNjobg0zk/s1958/image1.png" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="1233" data-original-width="1958" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjIBKUfxjF1_KB5gtaj8DRWoqLWQKe_DJXLV6-1sClG1oKutdKujDHyzYgvGlAhQDK235cBoKwNkj7cuA4kLzCt_sltdiyuZSMmEKdEoDS7_XkOFTujyekDI8gJfSLRZkT5yIdGPCVvEVQPoueDgK7dXgyAs04fK3AuwSMurECyNc3ywvzDLAyoNjobg0zk/s16000/image1.png" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;Illustration of how Chain-of-Table selects the next operation from the operation pool and generates the arguments for the operation. (a) Chain-of-Table samples the next operation from the operation pool. (b) It takes the selected operation as input and generates its arguments.&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;
&lt;br /&gt;

&lt;p&gt;
After the next operation &lt;em&gt;f&lt;/em&gt; is determined, in the second stage, we need to generate the arguments. As above, Chain-of-Table considers three components in the prompt as shown in the figure: (1) the question, (2) the selected operation and its required arguments, and (3) the latest intermediate table.  
&lt;/p&gt;
&lt;p&gt;
For instance, when the operation &lt;code&gt;f_group_by&lt;/code&gt; is selected, it requires a header name as its argument. 
&lt;/p&gt;
&lt;p&gt;
The LLM selects a suitable header within the table. Equipped with the selected operation and the generated arguments, Chain-of-Table executes the operation and constructs a new intermediate table for the following reasoning.
&lt;/p&gt;
&lt;p&gt;
Chain-of-Table repeats the previous two stages to plan the next operation and generate the required arguments. During this process, we create an operation chain acting as a proxy for the tabular reasoning steps. These operations generate intermediate tables presenting the results of each step to the LLM. Consequently, the output table contains comprehensive information about the intermediate phases of tabular reasoning. In the final stage, we use this output table to formulate the final query and prompt the LLM with it, along with the question, for the final answer.
&lt;/p&gt;
&lt;br /&gt; 

&lt;h2&gt;Experimental setup&lt;/h2&gt;


&lt;p&gt;
We use &lt;a href="https://ai.google/discover/palm2/"&gt;PaLM 2-S&lt;/a&gt;&amp;nbsp;and&amp;nbsp;&lt;a href="https://openai.com/blog/gpt-3-5-turbo-fine-tuning-and-api-updates"&gt;GPT 3.5&lt;/a&gt;&amp;nbsp;as the backbone LLMs and conduct the experiments on three public table understanding benchmarks: &lt;a href="https://arxiv.org/abs/1508.00305"&gt;WikiTQ&lt;/a&gt;, &lt;a href="https://arxiv.org/abs/1909.02164"&gt;TabFact&lt;/a&gt;, and &lt;a href="https://arxiv.org/abs/2104.00369"&gt;FeTaQA&lt;/a&gt;. WikiTQ and FeTaQA are datasets for table-based question answering. TabFact is a table-based fact verification benchmark. In this blog post, we focus on the results on WikiTQ and TabFact. We compare Chain-of-Table with the generic reasoning methods (e.g., End-to-End QA, Few-Shot QA, and &lt;a href="https://arxiv.org/abs/2201.11903"&gt;Chain-of-Thought&lt;/a&gt;) and the program-aided methods (e.g., &lt;a href="https://arxiv.org/abs/2204.00498"&gt;Text-to-SQL&lt;/a&gt;, &lt;a href="https://arxiv.org/abs/2210.02875"&gt;Binder&lt;/a&gt;, and &lt;a href="https://arxiv.org/abs/2301.13808"&gt;Dater&lt;/a&gt;). 
&lt;/p&gt;
&lt;div style="line-height: 40%;"&gt;
    &lt;br /&gt;
&lt;/div&gt;

&lt;h3&gt;More accurate answers&lt;/h3&gt;


&lt;p&gt;
Compared to the generic reasoning methods and program-aided reasoning methods, Chain-of-Table achieves better performance across &lt;a href="https://ai.google/discover/palm2/"&gt;PaLM 2&lt;/a&gt;&amp;nbsp;and&amp;nbsp;&lt;a href="https://openai.com/blog/gpt-3-5-turbo-fine-tuning-and-api-updates"&gt;GPT 3.5&lt;/a&gt;. This is attributed to the dynamically sampled operations and the informative intermediate tables.
&lt;/p&gt;&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEglv7DlLdRCDhXh2D8EE8DaOlnyYOBET9usjjD4jQkBMDH_sdWzf72QL6qo8F6wXP6ThhxggSjh-F-z0aah7Qr36ghB3muAAn2k0cjfKV9hBSRaIooRI30qkAbn9nft00DNKG0WjCfVxyNYGD3AciTo282wQDItTceKuDKo03KGTOWvm76HXK2PGgQM8h5o/s1018/ChainOfTableUnderstanding.png" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="755" data-original-width="1018" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEglv7DlLdRCDhXh2D8EE8DaOlnyYOBET9usjjD4jQkBMDH_sdWzf72QL6qo8F6wXP6ThhxggSjh-F-z0aah7Qr36ghB3muAAn2k0cjfKV9hBSRaIooRI30qkAbn9nft00DNKG0WjCfVxyNYGD3AciTo282wQDItTceKuDKo03KGTOWvm76HXK2PGgQM8h5o/s16000/ChainOfTableUnderstanding.png" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;&lt;span style="text-align: left;"&gt;Understanding results on WikiTQ and TabFact with PaLM 2 and GPT 3.5 compared with various models.&lt;/span&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;
&lt;div style="line-height: 40%;"&gt;
    &lt;br /&gt;
&lt;/div&gt;

&lt;h3&gt;Better robustness on harder questions&lt;/h3&gt;


&lt;p&gt;
In Chain-of-Table, longer operation chains indicate higher difficulty and complexity of the questions and their corresponding tables. We categorize the test samples according to the length of their operation chains in Chain-of-Table. We compare Chain-of-Table with Chain-of-Thought and Dater, as representative generic and program-aided reasoning methods. We illustrate this using results from &lt;a href="https://ai.google/discover/palm2/"&gt;PaLM 2&lt;/a&gt; on &lt;a href="https://arxiv.org/abs/1508.00305"&gt;WikiTQ&lt;/a&gt;. 
&lt;/p&gt;&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhONWxPX_gDzAJe0m3HLMjtdzFZ_EF_uCEvpxlMdex5KpSeo2iUzAzyETzzPEl8wbbawjtmw5JbVYXWSEjkwq-198INrSZEzXlLIly40_nr65KOcgQA96rC8Pz744FQaWdTfeIFbeBO6uhPD4NmOeU1dYUzXeoPUlNk2vZ4zd4JVB6TNIaEsHJohvlrSna7/s1548/CoTOpChainLength.png" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="624" data-original-width="1548" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhONWxPX_gDzAJe0m3HLMjtdzFZ_EF_uCEvpxlMdex5KpSeo2iUzAzyETzzPEl8wbbawjtmw5JbVYXWSEjkwq-198INrSZEzXlLIly40_nr65KOcgQA96rC8Pz744FQaWdTfeIFbeBO6uhPD4NmOeU1dYUzXeoPUlNk2vZ4zd4JVB6TNIaEsHJohvlrSna7/s16000/CoTOpChainLength.png" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;Performance of Chain-of-Thought, Dater, and the proposed Chain-of-Table on WikiTQ for questions that require an operation chain of varying lengths. Our proposed atomic operations significantly improve performance over generic and program-aided reasoning counterparts.&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;
&lt;br /&gt;

&lt;p&gt;
Notably, Chain-of-Table consistently surpasses both baseline methods across all operation chain lengths, by a significant margin of up to 11.6% compared with &lt;a href="https://arxiv.org/abs/2201.11903"&gt;Chain-of-Thought&lt;/a&gt;, and up to 7.9% compared with &lt;a href="https://arxiv.org/abs/2301.13808"&gt;Dater&lt;/a&gt;. Moreover, the performance of Chain-of-Table declines more gracefully than that of the baseline methods as the number of operations increases, exhibiting only a minimal decrease when the number of operations increases from four to five.
&lt;/p&gt;

&lt;div style="line-height: 40%;"&gt;
    &lt;br /&gt;
&lt;/div&gt;
&lt;h3&gt;Better robustness with larger tables&lt;/h3&gt;


&lt;p&gt;
We categorize the tables from &lt;a href="https://arxiv.org/abs/1508.00305"&gt;WikiTQ&lt;/a&gt; into three groups based on token number: small (&amp;lt;2000 tokens), medium (2000 to 4000 tokens) and large (&amp;gt;4000 tokens). We then compare Chain-of-Table with &lt;a href="https://arxiv.org/abs/2301.13808"&gt;Dater&lt;/a&gt; and &lt;a href="https://arxiv.org/abs/2210.02875"&gt;Binder&lt;/a&gt;, the two latest and strongest baselines. 
&lt;/p&gt;&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEg_SgXNNWocZbCKKXAju3cpc4r-cABNL8zsrRmXJYPTiS68R8GM3lkTdxJPXoT3niFVX1bvmL9_QHrozVdl4_vYCamVsaixakttU_-ha88xZhHSbg6M_I4VgG86iynnNwv9ywdcbh5vFtqTKAs2kMmFGZNx85WBM5-RBxI63vvMfau7WbLSkqA7yrOIguY_/s1999/image1.png" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="1008" data-original-width="1999" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEg_SgXNNWocZbCKKXAju3cpc4r-cABNL8zsrRmXJYPTiS68R8GM3lkTdxJPXoT3niFVX1bvmL9_QHrozVdl4_vYCamVsaixakttU_-ha88xZhHSbg6M_I4VgG86iynnNwv9ywdcbh5vFtqTKAs2kMmFGZNx85WBM5-RBxI63vvMfau7WbLSkqA7yrOIguY_/s16000/image1.png" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;&lt;span style="text-align: left;"&gt;Performance of Binder, Dater, and the proposed Chain-of-Table on small (&amp;lt;2000 tokens), medium (2000 to 4000 tokens), and large (&amp;gt;4000 tokens) tables from WikiTQ. We observe that the performance decreases with larger input tables while Chain-of-Table diminishes gracefully, achieving significant improvements over competing methods. (As above, underlined text denotes the second-best performance; bold denotes the best performance.)&lt;/span&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;&lt;br /&gt;

&lt;p&gt;
Performance of Binder, Dater, and the proposed Chain-of-Table on small (&amp;lt;2000 tokens), medium (2000 to 4000 tokens), and large (&amp;gt;4000 tokens) tables from WikiTQ. We observe that the performance decreases with larger input tables while Chain-of-Table diminishes gracefully, achieving significant improvements over competing methods. (As above, underlined text denotes the second-best performance; bold denotes the best performance.)
&lt;/p&gt;
&lt;p&gt;
As anticipated, the performance decreases with larger input tables, as models are required to reason through longer contexts. Nevertheless, the performance of the proposed Chain-of-Table diminishes gracefully, achieving a significant 10+% improvement over the second-best competing method when dealing with large tables. This demonstrates the efficacy of the reasoning chain in handling long tabular inputs.
&lt;/p&gt;
&lt;br /&gt;

&lt;h2&gt;Conclusion&lt;/h2&gt;


&lt;p&gt;
Our proposed Chain-of-Table method enhances the reasoning capability of LLMs by leveraging the tabular structure to express intermediate steps for table-based reasoning. It instructs LLMs to dynamically plan an operation chain according to the input table and its associated question. This evolving-table design sheds new light on how to prompt LLMs for table understanding.
&lt;/p&gt;
&lt;br /&gt; 

&lt;h2&gt;Acknowledgements&lt;/h2&gt;


&lt;p&gt;
&lt;em&gt;This research was conducted by Zilong Wang, Hao Zhang, Chun-Liang Li, Julian Martin Eisenschlos, Vincent Perot, Zifeng Wang, Lesly Miculicich, Yasuhisa Fujii, Jingbo Shang, Chen-Yu Lee, Tomas Pfister. Thanks to Chih-Kuan Yeh and Sergey Ioffe for their valuable feedback.&lt;/em&gt;
&lt;/p&gt;</content><link href="http://blog.research.google/feeds/470840348983280912/comments/default" rel="replies" title="Post Comments" type="application/atom+xml"/><link href="http://blog.research.google/2024/03/chain-of-table-evolving-tables-in.html#comment-form" rel="replies" title="0 Comments" type="text/html"/><link href="http://www.blogger.com/feeds/8474926331452026626/posts/default/470840348983280912" rel="edit" type="application/atom+xml"/><link href="http://www.blogger.com/feeds/8474926331452026626/posts/default/470840348983280912" rel="self" type="application/atom+xml"/><link href="http://blog.research.google/2024/03/chain-of-table-evolving-tables-in.html" rel="alternate" title="Chain-of-table: Evolving tables in the reasoning chain for table understanding" type="text/html"/><author><name>Google AI</name><uri>http://www.blogger.com/profile/12098626514775266161</uri><email>noreply@blogger.com</email><gd:image height="16" rel="http://schemas.google.com/g/2005#thumbnail" src="https://img1.blogblog.com/img/b16-rounded.gif" width="16"/></author><media:thumbnail xmlns:media="http://search.yahoo.com/mrss/" height="72" url="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEg1smBN07qkS32Aop4if0AeINQQea0Grv8dw7GiRFBNoHBlgkkftynVBNjO6BckpF4vq8d0VqC1v0LoeFAVFqOLrBGlqvMNiCMUtIhHxVvsBjbPxvZLcNcD_Sa1sI_bDlqDLWn_C39MbPNm8VUjr2vhTBuaL4qCc1LUB1VH5iM0UVsswIWWq_uQg88YRWmb/s72-c/Chain-of-Table.png" width="72"/><thr:total>0</thr:total></entry><entry><id>tag:blogger.com,1999:blog-8474926331452026626.post-1106624361649572376</id><published>2024-03-08T11:33:00.000-08:00</published><updated>2024-03-13T09:18:01.747-07:00</updated><category scheme="http://www.blogger.com/atom/ns#" term="Computer Vision"/><category scheme="http://www.blogger.com/atom/ns#" term="Health"/><category scheme="http://www.blogger.com/atom/ns#" term="Image Classification"/><title type="text">Health-specific embedding tools for dermatology and pathology</title><content type="html">&lt;span 
class="byline-author"&gt;Posted by Dave Steiner, Clinical Research Scientist, Google Health, and Rory Pilgrim, Product Manager, Google Research&lt;/span&gt;

&lt;img src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEi9zSpggPrlQvV-c0Lc2Sd79B58CwY0kDPJjgQfh-2SR8kiZuXO9A7LWZQ80zCqDNkYHm_IyNSQXF9xUOS-vPg8eJxkPR6HHuFr2VxoaAiAeG4J4ca6Pl8s9Jx1VX3tjQR0oA3I-oS2WujNwYJ2esmlfcyu1PZp7vh5MawdQc8Iu9aLM4fkAhycOXmumoKp/s16000/Path%20+%20Derm%20hero.jpg" style="display: none;" /&gt;

&lt;p&gt;
There’s a worldwide shortage of access to expert interpretation of medical images across specialties including &lt;a href="https://www.rsna.org/news/2022/may/Global-Radiologist-Shortage"&gt;radiology&lt;/a&gt;, &lt;a href="https://www.aad.org/dw/monthly/2021/december/feature-running-dry"&gt;dermatology&lt;/a&gt;, and &lt;a href="https://proscia.com/infographic-the-state-of-the-pathology-workforce-2022/"&gt;pathology&lt;/a&gt;. Machine learning (ML) technology can help ease this burden by powering tools that enable doctors to interpret these images more accurately and efficiently. However, the development and implementation of such ML tools are often limited by the availability of high-quality data, ML expertise, and computational resources. 
&lt;/p&gt;
&lt;a name='more'&gt;&lt;/a&gt;
&lt;p&gt;
One way to catalyze the use of ML for medical imaging is via domain-specific models that utilize deep learning (DL) to capture the information in medical images as compressed numerical vectors (called embeddings). These embeddings represent a type of pre-learned understanding of the important features in an image. Identifying patterns in the embeddings reduces the amount of data, expertise, and compute needed to train performant models as compared to &lt;a href="https://en.wikipedia.org/wiki/Curse_of_dimensionality"&gt;working with high-dimensional data&lt;/a&gt;, such as images, directly. Indeed, these embeddings can be used to perform a variety of downstream tasks within the specialized domain (see animated graphic below). This framework of leveraging pre-learned understanding to solve related tasks is similar to that of a seasoned guitar player quickly learning a new song by ear. Because the guitar player has already built up a foundation of skill and understanding, they can quickly pick up the patterns and groove of a new song. 
&lt;/p&gt;


&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEglGgLglQSBBcJqiT_SsxQf9AKGyrenZw28xTiqVP9qljNyD8mhpv-m4kl27u4NLm0FGJShNOuK456JIzdQ269xBx3fBi1u2ke10iE4THphEkD9MCCGrHjhrddtAHJ27g3pyznABW3i_CxTNkONPsH-BOcoFgS4A8tscJsJ42eD5XAHJ3FVzkfmltMzUKkq/s1600/Path%20+%20Derm%20train%20LP.gif" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="500" data-original-width="1600" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEglGgLglQSBBcJqiT_SsxQf9AKGyrenZw28xTiqVP9qljNyD8mhpv-m4kl27u4NLm0FGJShNOuK456JIzdQ269xBx3fBi1u2ke10iE4THphEkD9MCCGrHjhrddtAHJ27g3pyznABW3i_CxTNkONPsH-BOcoFgS4A8tscJsJ42eD5XAHJ3FVzkfmltMzUKkq/s16000/Path%20+%20Derm%20train%20LP.gif" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;Path Foundation is used to convert a small dataset of (image, label) pairs into (embedding, label) pairs. These pairs can then be used to train a task-specific classifier using a linear probe (i.e., a lightweight linear classifier), as represented in this graphic, or other types of models using the embeddings as input.&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;

&lt;br /&gt;



&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhZeZ25Ea3ZXz8hd6YWMkECnI0jCWnsorTZ0Ob97G-94OZfE3vVtq27pAAmZufyRHfRjUVag-ViN2bIchtZ0eCl5mUIHldWQ8e0lEJAQhYy_Ae3JTCh9Sjc2izTny5I1fo5QxxZTzwvvIKzXNNugSpyYVnUplnm54zRNRKf38EhDU4hEcHYuqqbHdlxQyyz/s1600/Path%20+%20Derm%20-%20evaluate%20LP.gif" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="500" data-original-width="1600" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhZeZ25Ea3ZXz8hd6YWMkECnI0jCWnsorTZ0Ob97G-94OZfE3vVtq27pAAmZufyRHfRjUVag-ViN2bIchtZ0eCl5mUIHldWQ8e0lEJAQhYy_Ae3JTCh9Sjc2izTny5I1fo5QxxZTzwvvIKzXNNugSpyYVnUplnm54zRNRKf38EhDU4hEcHYuqqbHdlxQyyz/s16000/Path%20+%20Derm%20-%20evaluate%20LP.gif" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;Once the linear probe is trained, it can be used to make predictions on embeddings from new images. These predictions can be compared to ground truth information in order to evaluate the linear probe's performance.&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;



&lt;p&gt;
In order to make this type of embedding model available and drive further development of ML tools in medical imaging, we are excited to release two domain-specific tools for research use: &lt;a href="https://github.com/Google-Health/imaging-research/tree/master/derm-foundation"&gt;Derm Foundation&lt;/a&gt; and &lt;a href="https://github.com/Google-Health/imaging-research/tree/master/path-foundation"&gt;Path Foundation&lt;/a&gt;. This follows the strong response we’ve already received from researchers using the &lt;a href="https://blog.research.google/2022/07/simplified-transfer-learning-for-chest.html"&gt;CXR Foundation&lt;/a&gt; embedding tool for chest radiographs and represents a portion of our expanding research offerings across multiple specialized medical imaging modalities. These embedding tools take an image as input and produce a numerical vector (the embedding) that is specialized to the domains of dermatology and digital pathology images, respectively. By running a dataset of chest X-ray, dermatology, or pathology images through the respective embedding tool, researchers can obtain embeddings for their own images and use these embeddings to quickly develop new models for their applications.
&lt;/p&gt;

&lt;div style="line-height: 40%;"&gt;
    &lt;br /&gt;
&lt;/div&gt;
&lt;h2&gt;Path Foundation&lt;/h2&gt;


&lt;p&gt;
In “&lt;a href="https://arxiv.org/abs/2310.13259"&gt;Domain-specific optimization and diverse evaluation of self-supervised models for histopathology&lt;/a&gt;”, we showed that self-supervised learning (SSL) models for pathology images outperform traditional pre-training approaches and enable efficient training of classifiers for downstream tasks. This effort focused on &lt;a href="https://en.wikipedia.org/wiki/H%26E_stain"&gt;hematoxylin and eosin&lt;/a&gt; (H&amp;amp;E) stained slides, the principal tissue stain in diagnostic pathology that enables pathologists to visualize cellular features under a microscope. The performance of linear classifiers trained using the output of the SSL models matched that of prior DL models trained on orders of magnitude more labeled data. 
&lt;/p&gt;

&lt;p&gt;
Due to substantial differences between digital pathology images and “natural image” photos, this work involved several pathology-specific optimizations during model training. One key element is that &lt;a href="https://www.ncbi.nlm.nih.gov/pmc/articles/PMC7522141/"&gt;whole-slide images&lt;/a&gt; (WSIs) in pathology can be 100,000 pixels across (thousands of times larger than typical smartphone photos) and are analyzed by experts at multiple magnifications (zoom levels). As such, the WSIs are typically broken down into smaller tiles or patches for computer vision and DL applications. The resulting images are information-dense, with cells or tissue structures distributed throughout the frame instead of having distinct semantic objects or foreground vs. background variations, thus creating unique challenges for robust SSL and feature extraction. Additionally, physical (e.g., &lt;a href="https://en.wikipedia.org/wiki/Microtome"&gt;cutting&lt;/a&gt;) and chemical (e.g., &lt;a href="https://en.wikipedia.org/wiki/Fixation_(histology)"&gt;fixing&lt;/a&gt; and &lt;a href="https://en.wikipedia.org/wiki/Staining"&gt;staining&lt;/a&gt;) processes used to prepare the samples can influence image appearance dramatically. 
&lt;/p&gt;

&lt;p&gt;
Taking these important aspects into consideration, pathology-specific SSL optimizations included helping the model learn &lt;a href="https://arxiv.org/abs/2206.12694"&gt;stain-agnostic features&lt;/a&gt;, generalizing the model to patches from multiple magnifications, &lt;a href="https://blog.research.google/2020/02/generating-diverse-synthetic-medical.html"&gt;augmenting&lt;/a&gt; the data to mimic scanning and image post-processing, and custom data balancing to improve input heterogeneity for SSL training. These approaches were extensively evaluated using a broad set of benchmark tasks spanning 17 different tissue types and 12 different tasks. 
&lt;/p&gt;


&lt;p&gt;
Utilizing the vision transformer (&lt;a href="https://github.com/google-research/vision_transformer"&gt;ViT-S/16&lt;/a&gt;) architecture, Path Foundation was selected as the best performing model from the optimization and evaluation process described above (and illustrated in the figure below). This model thus provides an important balance between performance and model size to enable valuable and scalable use in generating embeddings over the many individual image patches of large pathology WSIs.
&lt;/p&gt;

&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhG4jlO0GRCgYA3fe6CteF9PYvm3joBGIBPXakdWWaQ7ztTTBK36dmrtRpK1xoNVub8MTMvmCzkW0wfCCkYUH3fnvKk8hJb79o4vETQq0MhqS1JDBxWgYUwFkjtpnkgx5jBiDOxwovsfgqvpNzVGpz6CY6nTJzJgSgtuE2qDRzIb9O7fbHrhdNU1-IWPSXp/s1999/Path%20+%20Derm%20SSL.jpg" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="1097" data-original-width="1999" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhG4jlO0GRCgYA3fe6CteF9PYvm3joBGIBPXakdWWaQ7ztTTBK36dmrtRpK1xoNVub8MTMvmCzkW0wfCCkYUH3fnvKk8hJb79o4vETQq0MhqS1JDBxWgYUwFkjtpnkgx5jBiDOxwovsfgqvpNzVGpz6CY6nTJzJgSgtuE2qDRzIb9O7fbHrhdNU1-IWPSXp/s16000/Path%20+%20Derm%20SSL.jpg" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;SSL training with pathology-specific optimizations for Path Foundation.&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;


&lt;p&gt;
The value of domain-specific image representations can also be seen in the figure below, which shows the linear probing performance improvement of Path Foundation (as measured by &lt;a href="https://en.wikipedia.org/wiki/Receiver_operating_characteristic"&gt;AUROC&lt;/a&gt;) compared to traditional pre-training on natural images (&lt;a href="https://arxiv.org/abs/2104.10972"&gt;ImageNet-21k&lt;/a&gt;). This includes evaluation for tasks such as &lt;a href="https://jamanetwork.com/journals/jama/fullarticle/2665774"&gt;metastatic breast cancer detection in lymph nodes&lt;/a&gt;, &lt;a href="https://jamanetwork.com/journals/jamaoncology/fullarticle/2768225"&gt;prostate cancer grading&lt;/a&gt;, and &lt;a href="https://www.nature.com/articles/s41523-022-00478-y"&gt;breast cancer grading&lt;/a&gt;, among others. 
&lt;/p&gt;





&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjtMvTwce8mL0GYA3YTZP0Xc7ub_BYOHIvd9k4FAfnbd-XhpVFU3T9wAl7adebAGVYSWv0RraeV_NHj-0ZiVKQ94wUM9D6GzLSg-FU9ad_L5wN4lksjbWMhN_53FhuY0yGcFvYBU8AgTY7UJKm8z9vz-rH7wkr_m5TOY8gFjWh3YkxHcPMr1wLAkS4hnGkJ/s1999/Path%20+%20Derm%20embeddings.png" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="890" data-original-width="1999" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjtMvTwce8mL0GYA3YTZP0Xc7ub_BYOHIvd9k4FAfnbd-XhpVFU3T9wAl7adebAGVYSWv0RraeV_NHj-0ZiVKQ94wUM9D6GzLSg-FU9ad_L5wN4lksjbWMhN_53FhuY0yGcFvYBU8AgTY7UJKm8z9vz-rH7wkr_m5TOY8gFjWh3YkxHcPMr1wLAkS4hnGkJ/s16000/Path%20+%20Derm%20embeddings.png" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;Path Foundation embeddings significantly outperform traditional ImageNet embeddings as evaluated by linear probing across multiple evaluation tasks in histopathology.&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;
&lt;br /&gt;

 
&lt;div style="line-height: 40%;"&gt;
    &lt;br /&gt;
&lt;/div&gt;
&lt;h2&gt;Derm Foundation&lt;/h2&gt;


&lt;p&gt;
&lt;a href="https://github.com/Google-Health/imaging-research/tree/master/derm-foundation"&gt;Derm Foundation&lt;/a&gt; is an embedding tool derived from our research in applying DL to &lt;a href="https://blog.research.google/2019/09/using-deep-learning-to-inform.html"&gt;interpret images of dermatology conditions&lt;/a&gt; and includes our recent work that adds &lt;a href="https://arxiv.org/abs/2402.15566"&gt;improvements to generalize better to new datasets&lt;/a&gt;. Due to its dermatology-specific pre-training it has a latent understanding of features present in images of skin conditions and can be used to quickly develop models to classify skin conditions. The model underlying the API is a &lt;a href="https://github.com/google-research/big_transfer"&gt;BiT ResNet-101x3&lt;/a&gt; trained in two stages. The first pre-training stage uses contrastive learning, similar to &lt;a href="https://arxiv.org/abs/2010.00747"&gt;ConVIRT&lt;/a&gt;, to train on a large number of image-text pairs &lt;a href="https://blog.research.google/2017/07/revisiting-unreasonable-effectiveness.html"&gt;from the internet&lt;/a&gt;. In the second stage, the image component of this pre-trained model is then fine-tuned for condition classification using clinical datasets, such as those from teledermatology services.
&lt;/p&gt;

&lt;p&gt;
Unlike histopathology images, dermatology images more closely resemble the real-world images used to train many of today's computer vision models. However, for specialized dermatology tasks, creating a high-quality model may still require a large dataset. With Derm Foundation, researchers can use their own smaller dataset to retrieve domain-specific embeddings, and use those to build smaller models (e.g., linear classifiers or other small non-linear models) that enable them to validate their research or product ideas. To evaluate this approach, we trained models on a downstream task using teledermatology data. Model training involved varying dataset sizes (12.5%, 25%, 50%, 100%) to compare embedding-based linear classifiers against fine-tuning.
&lt;/p&gt;

&lt;p&gt;
The modeling variants considered were:
&lt;/p&gt;

&lt;ul&gt;

&lt;li&gt;A linear classifier on frozen embeddings from &lt;a href="https://github.com/google-research/big_transfer"&gt;BiT-M&lt;/a&gt; (a standard pre-trained image model)

&lt;/li&gt;&lt;li&gt;Fine-tuned version of BiT-M with an extra dense layer for the downstream task

&lt;/li&gt;&lt;li&gt;A linear classifier on frozen embeddings from the Derm Foundation API

&lt;/li&gt;&lt;li&gt;Fine-tuned version of the model underlying the Derm Foundation API with an extra layer for the downstream task
&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;
We found that models built on top of the Derm Foundation embeddings for dermatology-related tasks achieved significantly higher quality than those built solely on embeddings or fine-tuned from BiT-M. This advantage was found to be most pronounced for smaller training dataset sizes.
&lt;/p&gt;


&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEi3cFSDBqVdsZm4MaFhMXli6kEJazYEB4xEYPB6ebOPv24HPd57Puw1zfu85raJ0gqfpnwsLW99Wh6aShuoCKZNYLw1PiG7eIqUEm8nMvwTy2qQTNL8ptn7cqBll127x_iEIsDMjznY5pWRIYF89cvBP3uPiVfMTgJS8aQpXiOC3oCO1Xl8CxTc4LXrLnjY/s1240/Path%20+%20Derm%20task%20accuracy.png" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="842" data-original-width="1240" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEi3cFSDBqVdsZm4MaFhMXli6kEJazYEB4xEYPB6ebOPv24HPd57Puw1zfu85raJ0gqfpnwsLW99Wh6aShuoCKZNYLw1PiG7eIqUEm8nMvwTy2qQTNL8ptn7cqBll127x_iEIsDMjznY5pWRIYF89cvBP3uPiVfMTgJS8aQpXiOC3oCO1Xl8CxTc4LXrLnjY/s16000/Path%20+%20Derm%20task%20accuracy.png" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;These results demonstrate that the Derm Foundation tool can serve as a useful starting point to accelerate skin-related modeling tasks. We aim to enable other researchers to build on the underlying features and representations of dermatology that the model has learned. &lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;

&lt;p&gt;
However, there are limitations with this analysis. We're still exploring how well these embeddings generalize across task types, patient populations, and image settings. Downstream models built using Derm Foundation still require careful evaluation to understand their expected performance in the intended setting.
&lt;/p&gt;

&lt;div style="line-height: 40%;"&gt;
    &lt;br /&gt;
&lt;/div&gt;
&lt;h2&gt;Access Path and Derm Foundation&lt;/h2&gt;


&lt;p&gt;
We envision that the Derm Foundation and Path Foundation embedding tools will enable a range of use cases, including efficient development of models for diagnostic tasks, quality assurance and pre-analytical workflow improvements, image indexing and curation, and biomarker discovery and validation. We are releasing both tools to the research community so they can explore the utility of the embeddings for their own dermatology and pathology data.
&lt;/p&gt;

&lt;p&gt;
To get access, please sign up to each tool's terms of service using the following Google Forms. 
&lt;/p&gt;

&lt;ul&gt;

&lt;li&gt;&lt;a href="https://docs.google.com/forms/d/e/1FAIpQLSe5icNBzU_lO2CwjLLIOwbqIcWnJC-m4Sl7MgvI9Lng3QT6Zg/viewform?resourcekey=0-dahJtiVe2CqYkNEdWPcXgw"&gt;Derm Foundation Access Form&lt;/a&gt;

&lt;/li&gt;&lt;li&gt;&lt;a href="https://docs.google.com/forms/d/1auyo2VkzlzuiAXavZy1AWUyQHAqO7T3BLK-7ofKUvug/viewform?resourcekey=0-Z9pRxjDI-kaDEUIiNfMAWQ"&gt;Path Foundation Access Form&lt;/a&gt;
&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;
After gaining access to each tool, you can use the API to retrieve embeddings from dermatology images or digital pathology images stored in Google Cloud. Approved users who are just curious to see the model and embeddings in action can use the provided example Colab notebooks to train models using public data for classifying &lt;a href="https://github.com/Google-Health/imaging-research/blob/master/derm-foundation/derm_foundation_demo.ipynb"&gt;six common skin conditions&lt;/a&gt; or identifying tumors in &lt;a href="https://github.com/Google-Health/imaging-research/blob/master/path-foundation/linear-classifier-demo.ipynb"&gt;histopathology patches&lt;/a&gt;. We look forward to seeing the range of use-cases these tools can unlock.
&lt;/p&gt;

&lt;div style="line-height: 40%;"&gt;
    &lt;br /&gt;
&lt;/div&gt;
&lt;h2&gt;Acknowledgements&lt;/h2&gt;


&lt;p&gt;
&lt;em&gt;We would like to thank the many collaborators who helped make this work possible including Yun Liu, Can Kirmizi, Fereshteh Mahvar, Bram Sterling, Arman Tajback, Kenneth Philbrik, Arnav Agharwal, Aurora Cheung, Andrew Sellergren, Boris Babenko, Basil Mustafa, Jan Freyberg, Terry Spitz, Yuan Liu, Pinal Bavishi, Ayush Jain, Amit Talreja, Rajeev Rikhye, Abbi Ward, Jeremy Lai, Faruk Ahmed, Supriya Vijay, Tiam Jaroensri, Jessica Loo, Saurabh Vyawahare, Saloni Agarwal, Ellery Wulczyn, Jonathan Krause, Fayaz Jamil, Tom Small, Annisah Um'rani, Lauren Winer, Sami Lachgar, Yossi Matias, Greg Corrado, and Dale Webster.&lt;/em&gt;
&lt;/p&gt;
</content><link href="http://blog.research.google/feeds/1106624361649572376/comments/default" rel="replies" title="Post Comments" type="application/atom+xml"/><link href="http://blog.research.google/2024/03/health-specific-embedding-tools-for.html#comment-form" rel="replies" title="0 Comments" type="text/html"/><link href="http://www.blogger.com/feeds/8474926331452026626/posts/default/1106624361649572376" rel="edit" type="application/atom+xml"/><link href="http://www.blogger.com/feeds/8474926331452026626/posts/default/1106624361649572376" rel="self" type="application/atom+xml"/><link href="http://blog.research.google/2024/03/health-specific-embedding-tools-for.html" rel="alternate" title="Health-specific embedding tools for dermatology and pathology" type="text/html"/><author><name>Google AI</name><uri>http://www.blogger.com/profile/12098626514775266161</uri><email>noreply@blogger.com</email><gd:image height="16" rel="http://schemas.google.com/g/2005#thumbnail" src="https://img1.blogblog.com/img/b16-rounded.gif" width="16"/></author><media:thumbnail xmlns:media="http://search.yahoo.com/mrss/" height="72" url="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEi9zSpggPrlQvV-c0Lc2Sd79B58CwY0kDPJjgQfh-2SR8kiZuXO9A7LWZQ80zCqDNkYHm_IyNSQXF9xUOS-vPg8eJxkPR6HHuFr2VxoaAiAeG4J4ca6Pl8s9Jx1VX3tjQR0oA3I-oS2WujNwYJ2esmlfcyu1PZp7vh5MawdQc8Iu9aLM4fkAhycOXmumoKp/s72-c/Path%20+%20Derm%20hero.jpg" width="72"/><thr:total>0</thr:total></entry><entry><id>tag:blogger.com,1999:blog-8474926331452026626.post-1765359719068432739</id><published>2024-03-07T10:15:00.000-08:00</published><updated>2024-03-07T10:19:31.177-08:00</updated><category scheme="http://www.blogger.com/atom/ns#" term="Large Language Models"/><category scheme="http://www.blogger.com/atom/ns#" term="Machine Intelligence"/><category scheme="http://www.blogger.com/atom/ns#" term="Natural Language Processing"/><title type="text">Social learning: Collaborative learning with large language models</title><content 
type="html">&lt;span class="byline-author"&gt;Posted by Amirkeivan Mohtashami, Research Intern, and Florian Hartmann, Software Engineer, Google Research&lt;/span&gt;

&lt;img src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEicN2GYOp9oUj5x0F20i550WuF5KmpD8iRqrdHmJFU_HmkdFY3RBF4mfn_99q8jtEPVm56a4NfjMGFJ79y3rygqjX46h23tlzSDde7iEbp8ytHsPa5-IsNKFFituSoPmtGk666gjyypTvVhhuin8FahZfhWPyDWqF5yWBIQ-Cf_DxQ7vrmTWIkA_tAJtm4v/s1999/image2.png" style="display: none;" /&gt;

&lt;p&gt;
Large language models (LLMs) have significantly improved the state of the art for solving tasks specified using natural language, often reaching performance close to that of people. As these models increasingly enable assistive agents, it could be beneficial for them to learn effectively from each other, much like people do in social settings, which would allow LLM-based agents to improve each other’s performance. 
&lt;/p&gt;
&lt;a name='more'&gt;&lt;/a&gt; 

&lt;p&gt;
To discuss the learning processes of humans, Bandura and Walters &lt;a href="https://books.google.ch/books/about/Social_Learning_Theory.html?id=IXvuAAAAMAAJ&amp;amp;redir_esc=y"&gt;described&lt;/a&gt; the concept of &lt;em&gt;social learning&lt;/em&gt; in 1977, outlining different models of observational learning used by people. One common method of learning from others is through a &lt;em&gt;verbal instruction&lt;/em&gt; (e.g., from a teacher) that describes how to engage in a particular behavior. Alternatively, learning can happen through a &lt;em&gt;live model&lt;/em&gt; by mimicking a live example of the behavior.
&lt;/p&gt;
&lt;p&gt;
Given the success of LLMs mimicking human communication, in our paper “&lt;a href="https://arxiv.org/abs/2312.11441"&gt;Social Learning: Towards Collaborative Learning with Large Language Models&lt;/a&gt;”, we investigate whether LLMs are able to learn from each other using social learning. To this end, we outline a framework for social learning in which LLMs share knowledge with each other in a privacy-aware manner using natural language. We evaluate the effectiveness of our framework on various datasets, and propose quantitative methods that measure privacy in this setting. In contrast to previous approaches to collaborative learning, such as common &lt;a href="https://blog.research.google/2017/04/federated-learning-collaborative.html"&gt;federated learning&lt;/a&gt; approaches that often rely on gradients, in our framework, agents teach each other purely using natural language.
&lt;/p&gt;
&lt;br /&gt; 

&lt;h2&gt;Social learning for LLMs&lt;/h2&gt;


&lt;p&gt;
To extend social learning to language models, we consider the scenario where a student LLM should learn to solve a task from multiple teacher entities that already know that task. In our paper, we evaluate the student’s performance on a variety of tasks, such as &lt;a href="https://dl.acm.org/doi/10.1145/2034691.2034742"&gt;spam detection&lt;/a&gt; in short text messages (SMS), solving &lt;a href="https://arxiv.org/abs/2110.14168"&gt;grade school math problems&lt;/a&gt;, and &lt;a href="https://arxiv.org/abs/1905.10044"&gt;answering questions&lt;/a&gt; based on a given text.   
&lt;/p&gt;

&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgAndq_MjAVBs4j3lmxEX71nMrCLpAasklndZyE8F7yj3slyafRsNauzW4yRxI_Ncg7Sp5jllAXpItsjA-BOmdB2O1jP3Awu09-DVRHBE_Urf58yzm5tDBBpM-aibZxmgA9O6CySCCRdSMMqG7vj-OU07jHa0OU0YixCxRB0Q3APMQbn8Vz5rEBp70ZNogH/s900/image3.gif" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="381" data-original-width="900" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgAndq_MjAVBs4j3lmxEX71nMrCLpAasklndZyE8F7yj3slyafRsNauzW4yRxI_Ncg7Sp5jllAXpItsjA-BOmdB2O1jP3Awu09-DVRHBE_Urf58yzm5tDBBpM-aibZxmgA9O6CySCCRdSMMqG7vj-OU07jHa0OU0YixCxRB0Q3APMQbn8Vz5rEBp70ZNogH/s16000/image3.gif" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;A visualization of the social learning process: A teacher model provides instructions or few-shot examples to a student model without sharing its private data.&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;

&lt;p&gt;
Language models have shown a remarkable capacity to perform tasks given only a handful of examples—a process called &lt;a href="https://arxiv.org/abs/2005.14165"&gt;few-shot learning&lt;/a&gt;. With this in mind, we provide human-labeled examples of a task that enable the teacher model to teach it to a student. One of the main use cases of social learning arises when these examples cannot be directly shared with the student due, for example, to privacy concerns. 
&lt;/p&gt;
&lt;p&gt;
To illustrate this, let’s look at a hypothetical example for a spam detection task. A teacher model is located on a device where some users volunteer to mark incoming messages they receive as either “spam” or “not spam”. This is useful data that could help train a student model to differentiate between spam and not spam, but sharing personal messages with other users is a breach of privacy and should be avoided. To prevent this, a social learning process can transfer the knowledge from the teacher model to the student so it learns what spam messages look like without needing to share the user’s personal text messages.
&lt;/p&gt;
&lt;p&gt;
We investigate the effectiveness of this social learning approach by analogy with the established human social learning theory that we discussed above. In these experiments, we use &lt;a href="https://blog.google/technology/ai/google-palm-2-ai-large-language-model/"&gt;PaLM 2-S&lt;/a&gt; models for both the teacher and the student.
&lt;/p&gt;

&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEicN2GYOp9oUj5x0F20i550WuF5KmpD8iRqrdHmJFU_HmkdFY3RBF4mfn_99q8jtEPVm56a4NfjMGFJ79y3rygqjX46h23tlzSDde7iEbp8ytHsPa5-IsNKFFituSoPmtGk666gjyypTvVhhuin8FahZfhWPyDWqF5yWBIQ-Cf_DxQ7vrmTWIkA_tAJtm4v/s1999/image2.png" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="1117" data-original-width="1999" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEicN2GYOp9oUj5x0F20i550WuF5KmpD8iRqrdHmJFU_HmkdFY3RBF4mfn_99q8jtEPVm56a4NfjMGFJ79y3rygqjX46h23tlzSDde7iEbp8ytHsPa5-IsNKFFituSoPmtGk666gjyypTvVhhuin8FahZfhWPyDWqF5yWBIQ-Cf_DxQ7vrmTWIkA_tAJtm4v/s16000/image2.png" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;A systems view of social learning: At training time, multiple teachers teach the student. At inference time, the student is using what it learned from the teachers.&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;
&lt;br /&gt;

&lt;h3&gt;Synthetic examples&lt;/h3&gt;


&lt;p&gt;
As a counterpart to the live teaching model described for traditional social learning, we propose a learning method where the teachers generate new synthetic examples for the task and share them with the student. This is motivated by the idea that one can create a new example that is sufficiently different from the original one, but is just as educational. Indeed, we observe that our generated examples are sufficiently different from the real ones to preserve privacy while still enabling performance comparable to that achieved using the original examples.
&lt;/p&gt;

&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEiBGMoLyGVpCFO2DkG61pJJwjfje3CZO9V_5YfK3FJlQrbqD8P1RnBt70-G1p0ifTVZ8hnN0upKFdnbZNkPeKpICUiYU0uoqftlq-1bvLXfwlzPFhsCf4uyD5Z4z_ML44YWVf-pjyWEbgsgKGEp_P5F7QzFH3P5TokVfw1QQhD2dSON4dDp3jXqZTHXYZSd/s1456/image5.png" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="880" data-original-width="1456" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEiBGMoLyGVpCFO2DkG61pJJwjfje3CZO9V_5YfK3FJlQrbqD8P1RnBt70-G1p0ifTVZ8hnN0upKFdnbZNkPeKpICUiYU0uoqftlq-1bvLXfwlzPFhsCf4uyD5Z4z_ML44YWVf-pjyWEbgsgKGEp_P5F7QzFH3P5TokVfw1QQhD2dSON4dDp3jXqZTHXYZSd/s16000/image5.png" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;The 8 generated examples perform as well as the original data for several tasks (see our&amp;nbsp;&lt;a href="https://arxiv.org/abs/2312.11441"&gt;paper&lt;/a&gt;).&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;

&lt;p&gt;
We evaluate the efficacy of learning through synthetic examples on our task suite. Especially when the number of examples is high enough, e.g., n = 16, we observe no statistically significant difference between sharing original data and teaching with synthesized data via social learning for the majority of tasks, indicating that the privacy improvement does not have to come at the cost of model quality. 
&lt;/p&gt;

&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhQPNMTVzgQW7O3o7Uz0a42vnT7kBhAjqRg5ZL1UrQVs7H5b5-FGdxJFcBmCGHr8sU3WkHsPKVlsQmVnzW-YAop1plz6oxYvTQyxEirorXE2WyGVfFvdOzAw5ydoMh7WUNykMJqasBqCr3C2n_pwBlAFZLO-WBiS-yXm9ExW_NTTIW8zYvfu17cMU8Y3_tp/s1456/image4.png" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="880" data-original-width="1456" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhQPNMTVzgQW7O3o7Uz0a42vnT7kBhAjqRg5ZL1UrQVs7H5b5-FGdxJFcBmCGHr8sU3WkHsPKVlsQmVnzW-YAop1plz6oxYvTQyxEirorXE2WyGVfFvdOzAw5ydoMh7WUNykMJqasBqCr3C2n_pwBlAFZLO-WBiS-yXm9ExW_NTTIW8zYvfu17cMU8Y3_tp/s16000/image4.png" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;Generating 16 instead of just 8 examples further reduces the performance gap relative to the original examples.&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;
&lt;br /&gt;

&lt;p&gt;
The one exception is spam detection, for which teaching with synthesized data yields lower accuracy. This may be because the training procedure of current models makes them biased to only generate non-spam examples. In the &lt;a href="https://arxiv.org/abs/2312.11441"&gt;paper&lt;/a&gt;, we additionally look into aggregation methods for selecting good subsets of examples to use.
&lt;/p&gt;
&lt;div style="line-height: 40%;"&gt;
    &lt;br /&gt;
&lt;/div&gt;

&lt;h3&gt;Synthetic instruction&lt;/h3&gt;


&lt;p&gt;
Given the success of language models in following instructions, the verbal instruction model can also be naturally adapted to language models by having the teachers generate an instruction for the task. Our experiments show that providing such a generated instruction effectively improves performance over zero-shot prompting, reaching accuracies comparable to few-shot prompting with original examples. However, we did find that the teacher model may fail on certain tasks to provide a good instruction, for example due to a complicated formatting requirement of the output. 
&lt;/p&gt;
&lt;p&gt;
For &lt;a href="https://arxiv.org/abs/1606.06031"&gt;Lambada&lt;/a&gt;, &lt;a href="https://arxiv.org/abs/2110.14168"&gt;GSM8k&lt;/a&gt;, and &lt;a href="https://arxiv.org/abs/2005.14165"&gt;Random Insertion&lt;/a&gt;, providing synthetic examples performs better than providing generated instructions, whereas for the other tasks, generated instructions obtain higher accuracy. This observation suggests that the choice of the teaching model depends on the task at hand, similar to how the most effective method for teaching people varies by task.
&lt;/p&gt;


&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhlmIYiQiqu5BGxrgWq6kklbYjnf3cEIE8lYcoIDQBYY54-ZQCTO2bm7IwpElQCD9ZX0Kt9_egKLhFjlmQFh-oJejJuLHHFDC-d_FVS9DzxGQNzEHy8nFL6BTs5D0evWbiDFjhy1p2OZ9u-QixTWFfP73SEWa2L5iax9OGFvwfuGvi5bsr2EzCSEUYONJ5r/s1451/image1.png" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="880" data-original-width="1451" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhlmIYiQiqu5BGxrgWq6kklbYjnf3cEIE8lYcoIDQBYY54-ZQCTO2bm7IwpElQCD9ZX0Kt9_egKLhFjlmQFh-oJejJuLHHFDC-d_FVS9DzxGQNzEHy8nFL6BTs5D0evWbiDFjhy1p2OZ9u-QixTWFfP73SEWa2L5iax9OGFvwfuGvi5bsr2EzCSEUYONJ5r/s16000/image1.png" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;Depending on the task, generating instructions can work better than generating new examples.&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;

&lt;br /&gt;

&lt;h2&gt;Memorization of the private examples&lt;/h2&gt;


&lt;p&gt;
We want teachers in social learning to teach the student without revealing specifics from the original data. To quantify how prone this process is to leaking information, we used &lt;a href="https://research.google/pubs/the-secret-sharer-evaluating-and-testing-unintended-memorization-in-neural-networks/"&gt;Secret Sharer&lt;/a&gt;, a popular method for quantifying to what extent a model memorizes its training data, and adapted it to the social learning setting. We picked this method since it had previously been &lt;a href="https://blog.research.google/2023/03/distributed-differential-privacy-for.html"&gt;used&lt;/a&gt; for evaluating memorization in federated learning.
&lt;/p&gt;
&lt;p&gt;
To apply the Secret Sharer method to social learning, we design “canary” data points such that we can concretely measure how much the training process memorized them. These data points are included in the datasets used by teachers to generate new examples. After the social learning process completes, we can then measure how much more confident the student is in the secret data points the teacher used, compared to similar ones that were not shared even with the teachers.
&lt;/p&gt;
&lt;p&gt;
In our analysis, discussed in detail in the &lt;a href="https://arxiv.org/abs/2312.11441"&gt;paper&lt;/a&gt;, we use canary examples that include names and codes. Our results show that the student is only slightly more confident in the canaries the teacher used. In contrast, when the original data points are directly shared with the student, the confidence in the included canaries is much higher than in the held-out set. This supports the conclusion that the teacher does indeed use its data to teach without simply copying it over.
&lt;/p&gt;
&lt;br /&gt; 

&lt;h2&gt;Conclusion and next steps&lt;/h2&gt;


&lt;p&gt;
We introduced a framework for social learning that allows language models with access to private data to transfer knowledge through textual communication while maintaining the privacy of that data. In this framework, we identified sharing examples and sharing instructions as basic models and evaluated them on multiple tasks. Furthermore, we adapted the Secret Sharer metric to our framework, proposing a metric for measuring data leakage.
&lt;/p&gt;
&lt;p&gt;
As next steps, we are looking for ways of improving the teaching process, for example by adding feedback loops and iteration. Furthermore, we want to investigate using social learning for modalities other than text.
&lt;/p&gt;
&lt;br /&gt; 

&lt;h2&gt;Acknowledgements&lt;/h2&gt;


&lt;p&gt;
&lt;em&gt;We would like to acknowledge and thank Matt Sharifi, Sian Gooding, Lukas Zilka, and Blaise Aguera y Arcas, who are all co-authors on the paper. Furthermore, we would like to thank Victor Cărbune, Zachary Garrett, Tautvydas Misiunas, Sofia Neata and John Platt for their feedback, which greatly improved the paper. We’d also like to thank Tom Small for creating the animated figure.&lt;/em&gt;
&lt;/p&gt;</content><link href="http://blog.research.google/feeds/1765359719068432739/comments/default" rel="replies" title="Post Comments" type="application/atom+xml"/><link href="http://blog.research.google/2024/03/social-learning-collaborative-learning.html#comment-form" rel="replies" title="0 Comments" type="text/html"/><link href="http://www.blogger.com/feeds/8474926331452026626/posts/default/1765359719068432739" rel="edit" type="application/atom+xml"/><link href="http://www.blogger.com/feeds/8474926331452026626/posts/default/1765359719068432739" rel="self" type="application/atom+xml"/><link href="http://blog.research.google/2024/03/social-learning-collaborative-learning.html" rel="alternate" title="Social learning: Collaborative learning with large language models" type="text/html"/><author><name>Google AI</name><uri>http://www.blogger.com/profile/12098626514775266161</uri><email>noreply@blogger.com</email><gd:image height="16" rel="http://schemas.google.com/g/2005#thumbnail" src="https://img1.blogblog.com/img/b16-rounded.gif" width="16"/></author><media:thumbnail xmlns:media="http://search.yahoo.com/mrss/" height="72" url="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEicN2GYOp9oUj5x0F20i550WuF5KmpD8iRqrdHmJFU_HmkdFY3RBF4mfn_99q8jtEPVm56a4NfjMGFJ79y3rygqjX46h23tlzSDde7iEbp8ytHsPa5-IsNKFFituSoPmtGk666gjyypTvVhhuin8FahZfhWPyDWqF5yWBIQ-Cf_DxQ7vrmTWIkA_tAJtm4v/s72-c/image2.png" width="72"/><thr:total>0</thr:total></entry><entry><id>tag:blogger.com,1999:blog-8474926331452026626.post-8393293208018757284</id><published>2024-03-06T10:26:00.000-08:00</published><updated>2024-03-06T14:44:03.387-08:00</updated><category scheme="http://www.blogger.com/atom/ns#" term="Collaboration"/><category scheme="http://www.blogger.com/atom/ns#" term="datasets"/><category scheme="http://www.blogger.com/atom/ns#" term="ML"/><title type="text">Croissant: a metadata format for ML-ready datasets</title><content type="html">&lt;span class="byline-author"&gt;Posted by Omar 
Benjelloun, Software Engineer, Google Research, and Peter Mattson, Software Engineer, Google Core ML and President, MLCommons Association&lt;/span&gt;

&lt;img src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEj09uSTHgWmPgOkD9W1nZZj5i8uW_-pgxm-T1O5PSacF-EKvHIeIwhMr7Rgft7O3A2Rk94GWe8WboO3dUlxrqt1xz9x4I2aMKJxCUtUkR2eukbsIa8xVyAAN_LJJyMABxRqJuktFkyfhoWPDMQK3O-XgbQNJXzAILlWl3su0fd-Q_uZ-8r5r_uAU2P4srnP/s1600/CroissantHero.png" style="display: none;" /&gt;



&lt;p&gt;
Machine learning (ML) practitioners looking to reuse existing datasets to train an ML model often spend a lot of time understanding the data, making sense of its organization, or figuring out what subset to use as features. So much time, in fact, that progress in the field of ML is hampered by a fundamental obstacle: the wide variety of data representations. 
&lt;/p&gt;
&lt;a name='more'&gt;&lt;/a&gt;


&lt;p&gt;
ML datasets cover a broad range of content types, from text and structured data to images, audio, and video. Even within datasets that cover the same types of content, every dataset has a unique &lt;em&gt;ad hoc&lt;/em&gt; arrangement of files and data formats. This challenge reduces productivity throughout the entire ML development process, from finding the data to training the model. It also impedes development of badly needed tooling for working with datasets. 
&lt;/p&gt;
&lt;p&gt;
There are general purpose metadata formats for datasets such as &lt;a href="http://schema.org/Dataset"&gt;schema.org&lt;/a&gt; and &lt;a href="https://www.w3.org/TR/vocab-dcat-3/"&gt;DCAT&lt;/a&gt;. However, these formats were designed for data discovery rather than for the specific needs of ML data, such as the ability to extract and combine data from structured and unstructured sources, to include metadata that would enable &lt;a href="https://ai.google/responsibility/responsible-ai-practices/"&gt;responsible use&lt;/a&gt; of the data, or to describe ML usage characteristics such as defining training, test and validation sets. 
&lt;/p&gt;
&lt;p&gt;
Today, we're introducing &lt;a href="https://mlcommons.org/croissant"&gt;Croissant&lt;/a&gt;, a new metadata format for ML-ready datasets. Croissant was developed collaboratively by a community from industry and academia, as part of the &lt;a href="https://mlcommons.org/"&gt;MLCommons&lt;/a&gt; effort. The Croissant format doesn't change how the actual data is represented (e.g., image or text file formats) — it provides a standard way to describe and organize it. Croissant builds upon &lt;a href="https://schema.org/"&gt;schema.org&lt;/a&gt;, the de facto standard for publishing structured data on the Web, which is already used by over 40M datasets. Croissant augments it with comprehensive layers for ML-relevant metadata, data resources, data organization, and default ML semantics.
&lt;/p&gt;
&lt;p&gt;
In addition, we are announcing support from major tools and repositories: Today, three widely used collections of ML datasets — &lt;a href="http://www.kaggle.com/datasets"&gt;Kaggle&lt;/a&gt;, &lt;a href="https://huggingface.co/datasets?other=croissant&amp;amp;sort=trending"&gt;Hugging Face&lt;/a&gt;, and &lt;a href="https://openml.org/search?type=data"&gt;OpenML&lt;/a&gt; — will begin supporting the Croissant format for the datasets they host; the &lt;a href="http://g.co/datasetsearch"&gt;Dataset Search&lt;/a&gt; tool lets users search for Croissant datasets across the Web; and popular ML frameworks, including &lt;a href="https://www.tensorflow.org/"&gt;TensorFlow&lt;/a&gt;, &lt;a href="https://pytorch.org/"&gt;PyTorch&lt;/a&gt;, and &lt;a href="https://github.com/google/jax"&gt;JAX&lt;/a&gt;, can load Croissant datasets easily using the &lt;a href="https://www.tensorflow.org/datasets"&gt;TensorFlow Datasets&lt;/a&gt; (TFDS) package.
&lt;/p&gt;


&lt;div style="line-height:40%;"&gt;
    &lt;br&gt;
&lt;/div&gt;
&lt;h2&gt;Croissant&lt;/h2&gt;


&lt;p&gt;
This 1.0 release of Croissant includes a complete &lt;a href="https://mlcommons.org/croissant/1.0"&gt;specification&lt;/a&gt; of the format, a set of &lt;a href="https://github.com/mlcommons/croissant/tree/main/datasets"&gt;example datasets&lt;/a&gt;, an open source &lt;a href="https://github.com/mlcommons/croissant/tree/main/python/mlcroissant"&gt;Python library&lt;/a&gt; to validate, consume and generate Croissant metadata, and an open source &lt;a href="https://github.com/mlcommons/croissant/tree/main/editor"&gt;visual editor&lt;/a&gt; to load, inspect and create Croissant dataset descriptions in an intuitive way.
&lt;/p&gt;
&lt;p&gt;
Supporting Responsible AI (RAI) was a key goal of the Croissant effort from the start. We are also releasing the first version of the &lt;a href="https://mlcommons.org/croissant/RAI/1.0"&gt;Croissant RAI vocabulary&lt;/a&gt; extension, which augments Croissant with key properties needed to describe important RAI use cases such as data life cycle management, data labeling, participatory data, ML safety and fairness evaluation, explainability, and compliance.
&lt;/p&gt;

&lt;div style="line-height:40%;"&gt;
    &lt;br&gt;
&lt;/div&gt;
&lt;h2&gt;Why a shared format for ML data?&lt;/h2&gt;
&lt;p&gt;
The majority of ML work is actually data work. The training data is the “code” that determines the behavior of a model. Datasets can vary from a collection of text used to train a large language model (LLM) to a collection of driving scenarios (annotated videos) used to train a car’s collision avoidance system. However, the steps to develop an ML model typically follow the same iterative data-centric process: (1) find or collect data, (2) clean and refine the data, (3) train the model on the data, (4) test the model on more data, (5) discover the model does not work, (6) analyze the data to find out why, (7) repeat until a workable model is achieved. Many steps are made harder by the lack of a common format. This “data development burden” is especially heavy for resource-limited research and early-stage entrepreneurial efforts. 
&lt;/p&gt;
&lt;p&gt;
The goal of a format like Croissant is to make this entire process easier. For instance, the metadata can be leveraged by search engines and dataset repositories to make it easier to find the right dataset. The data resources and organization information make it easier to develop tools for cleaning, refining, and analyzing data. This information and the default ML semantics make it possible for ML frameworks to use the data to train and test models with a minimum of code. Together, these improvements substantially reduce the data development burden.
&lt;/p&gt;
&lt;p&gt;
Additionally, dataset authors care about the discoverability and ease of use of their datasets. Adopting Croissant improves the value of their datasets while requiring only minimal effort, thanks to the available creation tools and support from ML data platforms.
&lt;/p&gt;



&lt;div style="line-height:40%;"&gt;
    &lt;br&gt;
&lt;/div&gt;
&lt;h2&gt;What can Croissant do today?&lt;/h2&gt;


&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgN40ZSjgTFRIVwAwN2OXIn4vQhmshC8VhcKx-ijY-sCQBH9qDkV3nrFz_YapZ0iAD-Svkyxblt6lpJFFHa4JfDqfY6RIL0RnVhtgBlLyh-1DnH8DUz7-TUSdSUIg5V2piqjmQ5Dw9MISeeSBvnMsie8jRrXOeHXfcTGQi0AHIeOYFuHYwDFSyRmBT8BHum/s908/image1.png" imageanchor="1" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="540" data-original-width="908" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgN40ZSjgTFRIVwAwN2OXIn4vQhmshC8VhcKx-ijY-sCQBH9qDkV3nrFz_YapZ0iAD-Svkyxblt6lpJFFHa4JfDqfY6RIL0RnVhtgBlLyh-1DnH8DUz7-TUSdSUIg5V2piqjmQ5Dw9MISeeSBvnMsie8jRrXOeHXfcTGQi0AHIeOYFuHYwDFSyRmBT8BHum/s16000/image1.png" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;The Croissant ecosystem: Users can Search for Croissant datasets, download them from major repositories, and easily load them into their favorite ML frameworks. They can create, inspect and modify Croissant metadata using the Croissant editor.&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;




&lt;p&gt;
Today, users can find Croissant datasets at:
&lt;/p&gt;
&lt;ul&gt;

&lt;li&gt;Google &lt;a href="https://datasetsearch.research.google.com/"&gt;Dataset Search&lt;/a&gt;, which offers a Croissant filter.

&lt;/li&gt;&lt;li&gt;&lt;a href="https://huggingface.co/datasets?other=croissant&amp;amp;sort=trending"&gt;HuggingFace&lt;/a&gt;

&lt;/li&gt;&lt;li&gt;&lt;a href="http://kaggle.com/datasets"&gt;Kaggle&lt;/a&gt;

&lt;/li&gt;&lt;li&gt;&lt;a href="https://openml.org/search?type=data"&gt;OpenML&lt;/a&gt;
&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;
With a Croissant dataset, it is possible to:
&lt;/p&gt;
&lt;ul&gt;

&lt;li&gt;Ingest data easily via &lt;a href="https://www.tensorflow.org/datasets"&gt;TensorFlow Datasets&lt;/a&gt; for use in popular ML frameworks like &lt;a href="https://www.tensorflow.org/"&gt;TensorFlow&lt;/a&gt;, &lt;a href="https://pytorch.org/"&gt;PyTorch&lt;/a&gt;, and &lt;a href="https://github.com/google/jax"&gt;JAX&lt;/a&gt;.

&lt;/li&gt;&lt;li&gt;Inspect and modify the metadata using the &lt;a href="https://huggingface.co/spaces/MLCommons/croissant-editor"&gt;Croissant editor UI&lt;/a&gt; (&lt;a href="https://github.com/mlcommons/croissant/tree/main/editor"&gt;github&lt;/a&gt;).
&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;
To publish a Croissant dataset, users can:
&lt;/p&gt;
&lt;ul&gt;

&lt;li&gt;Use the &lt;a href="https://huggingface.co/spaces/MLCommons/croissant-editor"&gt;Croissant editor UI&lt;/a&gt; (&lt;a href="https://github.com/mlcommons/croissant/tree/main/editor"&gt;github&lt;/a&gt;) to generate a large portion of Croissant metadata automatically by analyzing the data the user provides, and to fill important metadata fields such as RAI properties.

&lt;/li&gt;&lt;li&gt;Publish the Croissant information as part of their dataset Web page to make it discoverable and reusable.

&lt;/li&gt;&lt;li&gt;Publish their data in one of the repositories that support Croissant, such as Kaggle, Hugging Face, and OpenML, and automatically generate Croissant metadata.
&lt;/li&gt;
&lt;/ul&gt;



&lt;div style="line-height:40%;"&gt;
    &lt;br&gt;
&lt;/div&gt;
&lt;h2&gt;Future direction&lt;/h2&gt;


&lt;p&gt;
We are excited about Croissant's potential to help ML practitioners, but making this format truly useful requires the support of the community. We encourage dataset creators to consider providing Croissant metadata. We encourage platforms hosting datasets to provide Croissant files for download and embed Croissant metadata in dataset Web pages so that they can be made discoverable by dataset search engines. Tools that help users work with ML datasets, such as labeling or data analysis tools, should also consider supporting Croissant datasets. Together, we can reduce the data development burden and enable a richer ecosystem of ML research and development.
&lt;/p&gt;
&lt;p&gt;
We encourage the community to &lt;a href="http://mlcommons.org/croissant"&gt;join us&lt;/a&gt; in contributing to the effort.
&lt;/p&gt;


&lt;div style="line-height:40%;"&gt;
    &lt;br&gt;
&lt;/div&gt;
&lt;h2&gt;Acknowledgements&lt;/h2&gt;
&lt;p&gt;
&lt;em&gt;Croissant was developed by the &lt;a href="https://datasetsearch.research.google.com/"&gt;Dataset Search&lt;/a&gt;, &lt;a href="https://www.kaggle.com/"&gt;Kaggle&lt;/a&gt; and &lt;a href="https://www.tensorflow.org/datasets"&gt;TensorFlow Datasets&lt;/a&gt; teams from Google, as part of an &lt;a href="http://mlcommons.org"&gt;MLCommons&lt;/a&gt; community working group, which also includes contributors from these organizations: Bayer, cTuning Foundation, DANS-KNAW, Dotphoton, Harvard, Hugging Face, King's College London, LIST, Meta, NASA, North Carolina State University, Open Data Institute, Open University of Catalonia, Sage Bionetworks, and TU Eindhoven.&lt;/em&gt;
&lt;/p&gt;</content><link href="http://blog.research.google/feeds/8393293208018757284/comments/default" rel="replies" title="Post Comments" type="application/atom+xml"/><link href="http://blog.research.google/2024/03/croissant-metadata-format-for-ml-ready.html#comment-form" rel="replies" title="0 Comments" type="text/html"/><link href="http://www.blogger.com/feeds/8474926331452026626/posts/default/8393293208018757284" rel="edit" type="application/atom+xml"/><link href="http://www.blogger.com/feeds/8474926331452026626/posts/default/8393293208018757284" rel="self" type="application/atom+xml"/><link href="http://blog.research.google/2024/03/croissant-metadata-format-for-ml-ready.html" rel="alternate" title="Croissant: a metadata format for ML-ready datasets" type="text/html"/><author><name>Google AI</name><uri>http://www.blogger.com/profile/12098626514775266161</uri><email>noreply@blogger.com</email><gd:image height="16" rel="http://schemas.google.com/g/2005#thumbnail" src="https://img1.blogblog.com/img/b16-rounded.gif" width="16"/></author><media:thumbnail xmlns:media="http://search.yahoo.com/mrss/" height="72" url="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEj09uSTHgWmPgOkD9W1nZZj5i8uW_-pgxm-T1O5PSacF-EKvHIeIwhMr7Rgft7O3A2Rk94GWe8WboO3dUlxrqt1xz9x4I2aMKJxCUtUkR2eukbsIa8xVyAAN_LJJyMABxRqJuktFkyfhoWPDMQK3O-XgbQNJXzAILlWl3su0fd-Q_uZ-8r5r_uAU2P4srnP/s72-c/CroissantHero.png" width="72"/><thr:total>0</thr:total></entry><entry><id>tag:blogger.com,1999:blog-8474926331452026626.post-2754526782497247497</id><published>2024-03-04T07:06:00.000-08:00</published><updated>2024-03-05T08:40:45.490-08:00</updated><category scheme="http://www.blogger.com/atom/ns#" term="conference"/><category scheme="http://www.blogger.com/atom/ns#" term="conferences"/><category scheme="http://www.blogger.com/atom/ns#" term="Physics"/><category scheme="http://www.blogger.com/atom/ns#" term="Quantum AI"/><category scheme="http://www.blogger.com/atom/ns#" term="Quantum Computing"/><title 
type="text">Google at APS 2024</title><content type="html">&lt;span class="byline-author"&gt;Posted by Kate Weber and Shannon Leon, Google Research, Quantum AI Team&lt;/span&gt;

&lt;img src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjy22Hfq3RN4qRUJcSMUpIau4ueOIcQ219mDvfu4FNJ9kf5PBMUI0x4Uf9BhoIHtnFUhtvE72GCVYixldOZRSeePJfef0P87Pc_djQeGIZOhyxv9nKsQCc57357tr3npWdS5fyWxiGjex4NxMpOIB2JE1Z2qXdLnzLkFM075WstFJD77xVNS2T9hckWZyLf/s1600/lockup_GoogleResearch_FullColor_Hero.jpg" style="display: none;" /&gt;

&lt;p&gt;
Today the &lt;a href="https://www.aps.org/meetings/meeting.cfm?name=MAR24"&gt;2024 March Meeting&lt;/a&gt; of the &lt;a href="https://www.aps.org/"&gt;American Physical Society&lt;/a&gt; (APS) kicks off in Minneapolis, MN. A premier conference on topics ranging across physics and related fields, APS 2024 brings together researchers, students, and industry professionals to share their discoveries and build partnerships with the goal of realizing fundamental advances in physics-related sciences and technology. 
&lt;/p&gt;
&lt;a name='more'&gt;&lt;/a&gt;
&lt;p&gt;
This year, Google has a strong presence at APS with a booth hosted by the Google &lt;a href="https://quantumai.google/"&gt;Quantum AI&lt;/a&gt; team, 50+ talks throughout the conference, and participation in conference organizing activities, special sessions and events. Attending APS 2024 in person? Come visit Google’s Quantum AI booth to learn more about the exciting work we’re doing to solve some of the field’s most interesting challenges. &lt;!--Visit the &lt;a href="https://twitter.com/GoogleAI"&gt;@GoogleAI&lt;/a&gt; X (Twitter) account to find out about Google booth activities (e.g., demos and Q&amp;amp;A sessions).--&gt;
&lt;/p&gt;
&lt;p&gt;
You can learn more about the latest cutting-edge work we are presenting at the conference along with our schedule of booth events below (Googlers listed in &lt;strong&gt;bold&lt;/strong&gt;).
&lt;/p&gt;
&lt;div style="line-height: 40%;"&gt;
    &lt;br /&gt;
&lt;/div&gt;

&lt;h2&gt;Organizing Committee&lt;/h2&gt;
&lt;div style="margin-left: 20px;"&gt;

&lt;p&gt;

    Session Chairs include: &lt;strong&gt;Aaron Szasz&lt;/strong&gt;
&lt;/p&gt;
&lt;/div&gt;
&lt;div style="line-height: 40%;"&gt;
    &lt;br /&gt;
&lt;/div&gt;

&lt;h2&gt;Booth Activities&lt;/h2&gt;
&lt;div style="margin-left: 20px;"&gt;

&lt;p&gt;

    &lt;em&gt;This schedule is subject to change. Please visit the Google Quantum AI booth for more information.&lt;/em&gt;
&lt;/p&gt;
&lt;p&gt;

    Crumble: A prototype interactive tool for visualizing QEC circuits
&lt;br /&gt;
  Presenter: &lt;strong&gt;Matt McEwen&lt;/strong&gt;
&lt;br /&gt;
    Tue, Mar 5 | 11:00 AM CST
&lt;/p&gt;
&lt;p&gt;

    Qualtran: An open-source library for effective resource estimation of fault tolerant algorithms
&lt;br /&gt;

    Presenter: &lt;strong&gt;Tanuj Khattar&lt;/strong&gt;
&lt;br /&gt;

    Tue, Mar 5 | 2:30 PM CST
&lt;/p&gt;
&lt;p&gt;

    Qualtran: An open-source library for effective resource estimation of fault tolerant algorithms
&lt;br /&gt;

    Presenter: &lt;strong&gt;Tanuj Khattar&lt;/strong&gt;
&lt;br /&gt;

    Thu, Mar 7 | 11:00 AM CST
&lt;/p&gt;
&lt;p&gt;

    $5M XPRIZE / Google Quantum AI competition to accelerate quantum applications Q&amp;amp;A 
&lt;br /&gt;
    Presenter: &lt;strong&gt;Ryan Babbush&lt;/strong&gt;
&lt;br /&gt;
    Thu, Mar 7 | 11:00 AM CST
&lt;/p&gt;

&lt;/div&gt;
&lt;div style="line-height: 40%;"&gt;
    &lt;br /&gt;
&lt;/div&gt;


&lt;h2&gt;Talks&lt;/h2&gt;


&lt;h3&gt;Monday&lt;/h3&gt;

&lt;div style="margin-left: 20px;"&gt;

&lt;p&gt;

    &lt;a href="https://meetings.aps.org/Meeting/MAR24/Session/A45.1"&gt;Certifying highly-entangled states from few single-qubit measurements&lt;/a&gt;
&lt;br /&gt;
    Presenter: &lt;strong&gt;Hsin-Yuan Huang&lt;/strong&gt;
&lt;br /&gt;
    Author: &lt;strong&gt;Hsin-Yuan Huang&lt;/strong&gt;
&lt;br /&gt;
    &lt;em&gt;Session A45: New Frontiers in Machine Learning Quantum Physics&lt;/em&gt;
&lt;/p&gt;
&lt;p&gt;

    &lt;a href="https://meetings.aps.org/Meeting/MAR24/Session/A51.2"&gt;Toward high-fidelity analog quantum simulation with superconducting qubits&lt;/a&gt;
&lt;br /&gt;
    Presenter: &lt;strong&gt;Trond Andersen&lt;/strong&gt;
&lt;br /&gt;
    Authors: &lt;strong&gt;Trond I Andersen&lt;/strong&gt;, &lt;strong&gt;Xiao Mi&lt;/strong&gt;, &lt;strong&gt;Amir H Karamlou&lt;/strong&gt;, &lt;strong&gt;Nikita Astrakhantsev&lt;/strong&gt;, &lt;strong&gt;Andrey Klots&lt;/strong&gt;, &lt;strong&gt;Julia Berndtsson&lt;/strong&gt;, &lt;strong&gt;Andre Petukhov&lt;/strong&gt;, &lt;strong&gt;Dmitry Abanin&lt;/strong&gt;, &lt;strong&gt;Lev B Ioffe&lt;/strong&gt;, &lt;strong&gt;Yu Chen&lt;/strong&gt;, &lt;strong&gt;Vadim Smelyanskiy&lt;/strong&gt;, &lt;strong&gt;Pedram Roushan&lt;/strong&gt;
&lt;br /&gt;
    &lt;em&gt;Session A51: Applications on Noisy Quantum Hardware I&lt;/em&gt;
&lt;/p&gt;
&lt;p&gt;

    &lt;a href="https://meetings.aps.org/Meeting/MAR24/Session/B50.6"&gt;Measuring circuit errors in context for surface code circuits&lt;/a&gt;
&lt;br /&gt;
    Presenter: &lt;strong&gt;Dripto M Debroy&lt;/strong&gt;
&lt;br /&gt;
    Authors: &lt;strong&gt;Dripto M Debroy&lt;/strong&gt;, &lt;strong&gt;Jonathan A Gross&lt;/strong&gt;, &lt;strong&gt;Élie Genois&lt;/strong&gt;, &lt;strong&gt;Zhang Jiang&lt;/strong&gt;
&lt;br /&gt;
    &lt;em&gt;Session B50: Characterizing Noise with QCVV Techniques&lt;/em&gt;
&lt;/p&gt;
&lt;p&gt;

    &lt;a href="https://meetings.aps.org/Meeting/MAR24/Session/B51.6"&gt;Quantum computation of stopping power for inertial fusion target design I: Physics overview and the limits of classical algorithms&lt;/a&gt;
&lt;br /&gt;
    Presenter: Andrew D. Baczewski
&lt;br /&gt;
    Authors: &lt;strong&gt;Nicholas C. Rubin&lt;/strong&gt;, Dominic W. Berry, Alina Kononov, &lt;strong&gt;Fionn D. Malone&lt;/strong&gt;, &lt;strong&gt;Tanuj Khattar&lt;/strong&gt;, Alec White, &lt;strong&gt;Joonho Lee&lt;/strong&gt;, &lt;strong&gt;Hartmut Neven&lt;/strong&gt;, &lt;strong&gt;Ryan Babbush&lt;/strong&gt;, Andrew D. Baczewski
&lt;br /&gt;
    &lt;em&gt;Session B51: Heterogeneous Design for Quantum Applications&lt;/em&gt;
&lt;br /&gt;
    &lt;a href="https://arxiv.org/pdf/2308.12352.pdf"&gt;Link to Paper&lt;/a&gt;
&lt;/p&gt;
&lt;p&gt;

    &lt;a href="https://meetings.aps.org/Meeting/MAR24/Session/B51.7"&gt;Quantum computation of stopping power for inertial fusion target design II: Physics overview and the limits of classical algorithms&lt;/a&gt;
&lt;br /&gt;
    Presenter: &lt;strong&gt;Nicholas C. Rubin&lt;/strong&gt;
&lt;br /&gt;
    Authors: &lt;strong&gt;Nicholas C. Rubin&lt;/strong&gt;, Dominic W. Berry, Alina Kononov, &lt;strong&gt;Fionn D. Malone&lt;/strong&gt;, &lt;strong&gt;Tanuj Khattar&lt;/strong&gt;, Alec White, &lt;strong&gt;Joonho Lee&lt;/strong&gt;, &lt;strong&gt;Hartmut Neven&lt;/strong&gt;, &lt;strong&gt;Ryan Babbush&lt;/strong&gt;, Andrew D. Baczewski
&lt;br /&gt;
    &lt;em&gt;Session B51: Heterogeneous Design for Quantum Applications&lt;/em&gt;
&lt;br /&gt;
    &lt;a href="https://arxiv.org/pdf/2308.12352.pdf"&gt;Link to Paper&lt;/a&gt;
&lt;/p&gt;
&lt;p&gt;

    &lt;a href="https://meetings.aps.org/Meeting/MAR24/Session/B56.4"&gt;Calibrating Superconducting Qubits: From NISQ to Fault Tolerance&lt;/a&gt;
&lt;br /&gt;
    Presenter: &lt;strong&gt;Sabrina S Hong&lt;/strong&gt;
&lt;br /&gt;
    Author: &lt;strong&gt;Sabrina S Hong&lt;/strong&gt;
  &lt;br /&gt;
  &lt;em&gt;Session B56: From NISQ to Fault Tolerance&lt;/em&gt;
&lt;/p&gt;
&lt;p&gt;

    &lt;a href="https://meetings.aps.org/Meeting/MAR24/Session/B31.9"&gt;Measurement and feedforward induced entanglement negativity transition&lt;/a&gt;
&lt;br /&gt;
    Presenter: &lt;strong&gt;Ramis Movassagh&lt;/strong&gt;
&lt;br /&gt;
    Authors: Alireza Seif, Yu-Xin Wang, &lt;strong&gt;Ramis Movassagh&lt;/strong&gt;, Aashish A. Clerk
&lt;br /&gt;
    &lt;em&gt;Session B31: Measurement Induced Criticality in Many-Body Systems&lt;/em&gt;
&lt;br /&gt;
    &lt;a href="https://arxiv.org/pdf/2310.18305.pdf"&gt;Link to Paper&lt;/a&gt;
&lt;/p&gt;
&lt;p&gt;

    &lt;a href="https://meetings.aps.org/Meeting/MAR24/Session/B52.9"&gt;Effective quantum volume, fidelity and computational cost of noisy quantum processing experiments&lt;/a&gt;
&lt;br /&gt;
    Presenter: &lt;strong&gt;Salvatore Mandra&lt;/strong&gt;
&lt;br /&gt;
    Authors: &lt;strong&gt;Kostyantyn Kechedzhi&lt;/strong&gt;, &lt;strong&gt;Sergei V Isakov&lt;/strong&gt;, &lt;strong&gt;Salvatore Mandra&lt;/strong&gt;, &lt;strong&gt;Benjamin Villalonga&lt;/strong&gt;, &lt;strong&gt;X. Mi&lt;/strong&gt;, &lt;strong&gt;Sergio Boixo&lt;/strong&gt;, &lt;strong&gt;Vadim Smelyanskiy&lt;/strong&gt;
&lt;br /&gt;
    &lt;em&gt;Session B52: Quantum Algorithms and Complexity&lt;/em&gt;
&lt;br /&gt;
    &lt;a href="https://arxiv.org/pdf/2306.15970.pdf"&gt;Link to Paper&lt;/a&gt;
&lt;/p&gt;
&lt;p&gt;

    &lt;a href="https://meetings.aps.org/Meeting/MAR24/Session/D60.4"&gt;Accurate thermodynamic tables for solids using Machine Learning Interaction Potentials and Covariance of Atomic Positions&lt;/a&gt;
&lt;br /&gt;
    Presenter: Mgcini K Phuthi
&lt;br /&gt;
    Authors: Mgcini K Phuthi, Yang Huang, Michael Widom, &lt;strong&gt;Ekin D Cubuk&lt;/strong&gt;, Venkat Viswanathan
&lt;br /&gt;
    &lt;em&gt;Session D60: Machine Learning of Molecules and Materials: Chemical Space and Dynamics&lt;/em&gt;
&lt;/p&gt;
&lt;/div&gt;

&lt;div style="line-height: 40%;"&gt;
    &lt;br /&gt;
&lt;/div&gt;


&lt;h3&gt;Tuesday&lt;/h3&gt;
&lt;div style="margin-left: 20px;"&gt;

&lt;p&gt;

    &lt;a href="https://meetings.aps.org/Meeting/MAR24/Session/F50.4"&gt;IN-Situ Pulse Envelope Characterization Technique (INSPECT)&lt;/a&gt;
&lt;br /&gt;
    Presenter: &lt;strong&gt;Zhang Jiang&lt;/strong&gt;
&lt;br /&gt;
    Authors: &lt;strong&gt;Zhang Jiang&lt;/strong&gt;, &lt;strong&gt;Jonathan A Gross&lt;/strong&gt;, &lt;strong&gt;Élie Genois&lt;/strong&gt;
&lt;br /&gt;
    &lt;em&gt;Session F50: Advanced Randomized Benchmarking and Gate Calibration&lt;/em&gt;
&lt;/p&gt;
&lt;p&gt;

    &lt;a href="https://meetings.aps.org/Meeting/MAR24/Session/F50.11"&gt;Characterizing two-qubit gates with dynamical decoupling&lt;/a&gt;
&lt;br /&gt;
    Presenter: &lt;strong&gt;Jonathan A Gross&lt;/strong&gt;
&lt;br /&gt;
    Authors: &lt;strong&gt;Jonathan A Gross&lt;/strong&gt;, &lt;strong&gt;Zhang Jiang&lt;/strong&gt;, &lt;strong&gt;Élie Genois&lt;/strong&gt;, &lt;strong&gt;Dripto M Debroy&lt;/strong&gt;, Ze-Pei Cian*, &lt;strong&gt;Wojciech Mruczkiewicz&lt;/strong&gt;
&lt;br /&gt;
    &lt;em&gt;Session F50: Advanced Randomized Benchmarking and Gate Calibration&lt;/em&gt;
&lt;/p&gt;
&lt;p&gt;

    &lt;a href="https://meetings.aps.org/Meeting/MAR24/Session/EE01.2"&gt;Statistical physics of regression with quadratic models&lt;/a&gt;
&lt;br /&gt;
    Presenter: Blake Bordelon
&lt;br /&gt;
    Authors: Blake Bordelon, Cengiz Pehlevan, &lt;strong&gt;Yasaman Bahri&lt;/strong&gt;
&lt;br /&gt;
    &lt;em&gt;Session EE01: V: Statistical and Nonlinear Physics II&lt;/em&gt;
&lt;/p&gt;
&lt;p&gt;

    &lt;a href="https://meetings.aps.org/Meeting/MAR24/Session/G51.2"&gt;Improved state preparation for first-quantized simulation of electronic structure&lt;/a&gt;
&lt;br /&gt; 
  Presenter: &lt;strong&gt;William J Huggins&lt;/strong&gt;
&lt;br /&gt; 
  Authors: &lt;strong&gt;William J Huggins&lt;/strong&gt;, &lt;strong&gt;Oskar Leimkuhler&lt;/strong&gt;, &lt;strong&gt;Torin F Stetina&lt;/strong&gt;, &lt;strong&gt;Birgitta Whaley&lt;/strong&gt;
&lt;br /&gt; 
  &lt;em&gt;Session G51: Hamiltonian Simulation&lt;/em&gt;
&lt;/p&gt;
&lt;p&gt;

    &lt;a href="https://meetings.aps.org/Meeting/MAR24/Session/G30.2"&gt;Controlling large superconducting quantum processors&lt;/a&gt;
&lt;br /&gt;
    Presenter: &lt;strong&gt;Paul V. Klimov&lt;/strong&gt;
&lt;br /&gt;
    Authors: &lt;strong&gt;Paul V. Klimov&lt;/strong&gt;, &lt;strong&gt;Andreas Bengtsson&lt;/strong&gt;, &lt;strong&gt;Chris Quintana&lt;/strong&gt;, &lt;strong&gt;Alexandre Bourassa&lt;/strong&gt;, &lt;strong&gt;Sabrina Hong&lt;/strong&gt;, &lt;strong&gt;Andrew Dunsworth&lt;/strong&gt;, &lt;strong&gt;Kevin J. Satzinger&lt;/strong&gt;, &lt;strong&gt;William P. Livingston&lt;/strong&gt;, &lt;strong&gt;Volodymyr Sivak&lt;/strong&gt;, &lt;strong&gt;Murphy Y. Niu&lt;/strong&gt;, &lt;strong&gt;Trond I. Andersen&lt;/strong&gt;, &lt;strong&gt;Yaxing Zhang&lt;/strong&gt;, &lt;strong&gt;Desmond Chik&lt;/strong&gt;, &lt;strong&gt;Zijun Chen&lt;/strong&gt;, &lt;strong&gt;Charles Neill&lt;/strong&gt;, &lt;strong&gt;Catherine Erickson&lt;/strong&gt;, &lt;strong&gt;Alejandro Grajales Dau&lt;/strong&gt;, &lt;strong&gt;Anthony Megrant&lt;/strong&gt;, &lt;strong&gt;Pedram Roushan&lt;/strong&gt;, &lt;strong&gt;Alexander N. Korotkov&lt;/strong&gt;, &lt;strong&gt;Julian Kelly&lt;/strong&gt;, &lt;strong&gt;Vadim Smelyanskiy&lt;/strong&gt;, &lt;strong&gt;Yu Chen&lt;/strong&gt;, &lt;strong&gt;Hartmut Neven&lt;/strong&gt;
&lt;br /&gt;
    &lt;em&gt;Session G30: Commercial Applications of Quantum Computing&lt;/em&gt;&lt;br /&gt;
    &lt;a href="https://arxiv.org/pdf/2308.02321.pdf"&gt;Link to Paper&lt;/a&gt;
&lt;/p&gt;
&lt;p&gt;

    &lt;a href="https://meetings.aps.org/Meeting/MAR24/Session/G50.5"&gt;Gaussian boson sampling: Determining quantum advantage&lt;/a&gt;
&lt;br /&gt;
    Presenter: Peter D Drummond
&lt;br /&gt;
    Authors: Peter D Drummond, Alex Dellios, Ned Goodman, Margaret D Reid, &lt;strong&gt;Ben Villalonga&lt;/strong&gt;
&lt;br /&gt;
    &lt;em&gt;Session G50: Quantum Characterization, Verification, and Validation II&lt;/em&gt;
&lt;/p&gt;
&lt;p&gt;

    &lt;a href="https://meetings.aps.org/Meeting/MAR24/Session/G50.8"&gt;Attention to complexity III: learning the complexity of random quantum circuit states&lt;/a&gt;
&lt;br /&gt;
    Presenter: Hyejin Kim
&lt;br /&gt;
    Authors: Hyejin Kim, Yiqing Zhou, Yichen Xu, Chao Wan, Jin Zhou, &lt;strong&gt;Yuri D Lensky&lt;/strong&gt;, Jesse Hoke, &lt;strong&gt;Pedram Roushan&lt;/strong&gt;, Kilian Q Weinberger, Eun-Ah Kim
&lt;br /&gt;
    &lt;em&gt;Session G50: Quantum Characterization, Verification, and Validation II&lt;/em&gt;
&lt;/p&gt;
&lt;p&gt;

    &lt;a href="https://meetings.aps.org/Meeting/MAR24/Session/K48.10"&gt;Balanced coupling in superconducting circuits&lt;/a&gt;
&lt;br /&gt;
    Presenter: &lt;strong&gt;Daniel T Sank&lt;/strong&gt;
&lt;br /&gt;
    Authors: &lt;strong&gt;Daniel T Sank&lt;/strong&gt;, &lt;strong&gt;Sergei V Isakov&lt;/strong&gt;, &lt;strong&gt;Mostafa Khezri&lt;/strong&gt;, &lt;strong&gt;Juan Atalaya&lt;/strong&gt;
&lt;br /&gt;
    &lt;em&gt;Session K48: Strongly Driven Superconducting Systems&lt;/em&gt;
&lt;/p&gt;
&lt;p&gt;

    &lt;a href="https://meetings.aps.org/Meeting/MAR24/Session/K49.12"&gt;Resource estimation of Fault Tolerant algorithms using Qᴜᴀʟᴛʀᴀɴ&lt;/a&gt;
&lt;br /&gt;
    Presenter: &lt;strong&gt;Tanuj Khattar&lt;/strong&gt;
&lt;br /&gt;
    Authors: &lt;strong&gt;Tanuj Khattar&lt;/strong&gt;, &lt;strong&gt;Matthew Harrigan&lt;/strong&gt;, &lt;strong&gt;Fionn D. Malone&lt;/strong&gt;, &lt;strong&gt;Nour Yosri&lt;/strong&gt;, &lt;strong&gt;Nicholas C. Rubin&lt;/strong&gt;&lt;br /&gt;
    &lt;em&gt;Session K49: Algorithms and Implementations on Near-Term Quantum Computers&lt;/em&gt;
&lt;/p&gt;
&lt;/div&gt;

&lt;div style="line-height: 40%;"&gt;
    &lt;br /&gt;
&lt;/div&gt;


&lt;h3&gt;Wednesday&lt;/h3&gt;
&lt;div style="margin-left: 20px;"&gt;

&lt;p&gt;

    &lt;a href="https://meetings.aps.org/Meeting/MAR24/Session/M24.1"&gt;Discovering novel quantum dynamics with superconducting qubits&lt;/a&gt;
&lt;br /&gt;
    Presenter: &lt;strong&gt;Pedram Roushan&lt;/strong&gt;
&lt;br /&gt;
    Author: &lt;strong&gt;Pedram Roushan&lt;/strong&gt;
&lt;br /&gt;
    &lt;em&gt;Session M24: Analog Quantum Simulations Across Platforms&lt;/em&gt;
&lt;/p&gt;
&lt;p&gt;

    &lt;a href="https://meetings.aps.org/Meeting/MAR24/Session/M27.7"&gt;Deciphering Tumor Heterogeneity in Triple-Negative Breast Cancer: The Crucial Role of Dynamic Cell-Cell and Cell-Matrix Interactions&lt;/a&gt;
&lt;br /&gt;
    Presenter: Susan Leggett
&lt;br /&gt;
    Authors: Susan Leggett, Ian Wong, Celeste Nelson, Molly Brennan, &lt;strong&gt;Mohak Patel&lt;/strong&gt;, Christian Franck, Sophia Martinez, Joe Tien, Lena Gamboa, Thomas Valentin, Amanda Khoo, Evelyn K Williams 
&lt;br /&gt;
    &lt;em&gt;Session M27: Mechanics of Cells and Tissues II&lt;/em&gt;
&lt;/p&gt;
&lt;p&gt;

    &lt;a href="https://meetings.aps.org/Meeting/MAR24/Session/N48.2"&gt;Toward implementation of protected charge-parity qubits&lt;/a&gt;
&lt;br /&gt;
    Presenter: Abigail Shearrow
&lt;br /&gt;
    Authors: Abigail Shearrow, Matthew Snyder, Bradley G Cole, Kenneth R Dodge, Yebin Liu, Andrey Klots, &lt;strong&gt;Lev B Ioffe&lt;/strong&gt;, Britton L Plourde, Robert McDermott
&lt;br /&gt;
    &lt;em&gt;Session N48: Unconventional Superconducting Qubits&lt;/em&gt;
&lt;/p&gt;
&lt;p&gt;

    &lt;a href="https://meetings.aps.org/Meeting/MAR24/Session/N48.3"&gt;Electronic capacitance in tunnel junctions for protected charge-parity qubits&lt;/a&gt;
&lt;br /&gt;
    Presenter: Bradley G Cole
&lt;br /&gt;
    Authors: Bradley G Cole, Kenneth R Dodge, Yebin Liu, Abigail Shearrow, Matthew Snyder, &lt;strong&gt;Andrey Klots&lt;/strong&gt;, &lt;strong&gt;Lev B Ioffe&lt;/strong&gt;, Robert McDermott, B.L.T. Plourde
&lt;br /&gt;
    &lt;em&gt;Session N48: Unconventional Superconducting Qubits&lt;/em&gt;
&lt;/p&gt;
&lt;p&gt;

    &lt;a href="https://meetings.aps.org/Meeting/MAR24/Session/N51.7"&gt;Overcoming leakage in quantum error correction&lt;/a&gt;
&lt;br /&gt;
    Presenter: &lt;strong&gt;Kevin C. Miao&lt;/strong&gt;
&lt;br /&gt;
    Authors: &lt;strong&gt;Kevin C. Miao&lt;/strong&gt;, &lt;strong&gt;Matt McEwen&lt;/strong&gt;, &lt;strong&gt;Juan Atalaya&lt;/strong&gt;, &lt;strong&gt;Dvir Kafri&lt;/strong&gt;, &lt;strong&gt;Leonid P. Pryadko&lt;/strong&gt;, &lt;strong&gt;Andreas Bengtsson&lt;/strong&gt;, &lt;strong&gt;Alex Opremcak&lt;/strong&gt;, &lt;strong&gt;Kevin J. Satzinger&lt;/strong&gt;, &lt;strong&gt;Zijun Chen&lt;/strong&gt;, &lt;strong&gt;Paul V. Klimov&lt;/strong&gt;, &lt;strong&gt;Chris Quintana&lt;/strong&gt;, &lt;strong&gt;Rajeev Acharya&lt;/strong&gt;, &lt;strong&gt;Kyle Anderson&lt;/strong&gt;, &lt;strong&gt;Markus Ansmann&lt;/strong&gt;, &lt;strong&gt;Frank Arute&lt;/strong&gt;, &lt;strong&gt;Kunal Arya&lt;/strong&gt;, &lt;strong&gt;Abraham Asfaw&lt;/strong&gt;, &lt;strong&gt;Joseph C. Bardin&lt;/strong&gt;, &lt;strong&gt;Alexandre Bourassa&lt;/strong&gt;, &lt;strong&gt;Jenna Bovaird&lt;/strong&gt;, &lt;strong&gt;Leon Brill&lt;/strong&gt;, &lt;strong&gt;Bob B. Buckley&lt;/strong&gt;, &lt;strong&gt;David A. Buell&lt;/strong&gt;, &lt;strong&gt;Tim Burger&lt;/strong&gt;, &lt;strong&gt;Brian Burkett&lt;/strong&gt;, &lt;strong&gt;Nicholas Bushnell&lt;/strong&gt;, &lt;strong&gt;Juan Campero&lt;/strong&gt;, &lt;strong&gt;Ben Chiaro&lt;/strong&gt;, &lt;strong&gt;Roberto Collins&lt;/strong&gt;, &lt;strong&gt;Paul Conner&lt;/strong&gt;, &lt;strong&gt;Alexander L. Crook&lt;/strong&gt;, &lt;strong&gt;Ben Curtin&lt;/strong&gt;, &lt;strong&gt;Dripto M. Debroy&lt;/strong&gt;, &lt;strong&gt;Sean Demura&lt;/strong&gt;, &lt;strong&gt;Andrew Dunsworth&lt;/strong&gt;, &lt;strong&gt;Catherine Erickson&lt;/strong&gt;, &lt;strong&gt;Reza Fatemi&lt;/strong&gt;, &lt;strong&gt;Vinicius S. Ferreira&lt;/strong&gt;, &lt;strong&gt;Leslie Flores Burgos&lt;/strong&gt;, &lt;strong&gt;Ebrahim Forati&lt;/strong&gt;, &lt;strong&gt;Austin G. 
Fowler&lt;/strong&gt;, &lt;strong&gt;Brooks Foxen&lt;/strong&gt;, &lt;strong&gt;Gonzalo Garcia&lt;/strong&gt;, &lt;strong&gt;William Giang&lt;/strong&gt;, &lt;strong&gt;Craig Gidney&lt;/strong&gt;, &lt;strong&gt;Marissa Giustina&lt;/strong&gt;, &lt;strong&gt;Raja Gosula&lt;/strong&gt;, &lt;strong&gt;Alejandro Grajales Dau&lt;/strong&gt;, &lt;strong&gt;Jonathan A. Gross&lt;/strong&gt;, &lt;strong&gt;Michael C. Hamilton&lt;/strong&gt;, &lt;strong&gt;Sean D. Harrington&lt;/strong&gt;, &lt;strong&gt;Paula Heu&lt;/strong&gt;, &lt;strong&gt;Jeremy Hilton&lt;/strong&gt;, &lt;strong&gt;Markus R. Hoffmann&lt;/strong&gt;, &lt;strong&gt;Sabrina Hong&lt;/strong&gt;, &lt;strong&gt;Trent Huang&lt;/strong&gt;, &lt;strong&gt;Ashley Huff&lt;/strong&gt;, &lt;strong&gt;Justin Iveland&lt;/strong&gt;, &lt;strong&gt;Evan Jeffrey&lt;/strong&gt;, &lt;strong&gt;Zhang Jiang&lt;/strong&gt;, &lt;strong&gt;Cody Jones&lt;/strong&gt;, &lt;strong&gt;Julian Kelly&lt;/strong&gt;, &lt;strong&gt;Seon Kim&lt;/strong&gt;, &lt;strong&gt;Fedor Kostritsa&lt;/strong&gt;, &lt;strong&gt;John Mark Kreikebaum&lt;/strong&gt;, &lt;strong&gt;David Landhuis&lt;/strong&gt;, &lt;strong&gt;Pavel Laptev&lt;/strong&gt;, &lt;strong&gt;Lily Laws&lt;/strong&gt;, &lt;strong&gt;Kenny Lee&lt;/strong&gt;, &lt;strong&gt;Brian J. Lester&lt;/strong&gt;, &lt;strong&gt;Alexander T. 
Lill&lt;/strong&gt;, &lt;strong&gt;Wayne Liu&lt;/strong&gt;, &lt;strong&gt;Aditya Locharla&lt;/strong&gt;, &lt;strong&gt;Erik Lucero&lt;/strong&gt;, &lt;strong&gt;Steven Martin&lt;/strong&gt;, &lt;strong&gt;Anthony Megrant&lt;/strong&gt;, &lt;strong&gt;Xiao Mi&lt;/strong&gt;, &lt;strong&gt;Shirin Montazeri&lt;/strong&gt;, &lt;strong&gt;Alexis Morvan&lt;/strong&gt;, &lt;strong&gt;Ofer Naaman&lt;/strong&gt;, &lt;strong&gt;Matthew Neeley&lt;/strong&gt;, &lt;strong&gt;Charles Neill&lt;/strong&gt;, &lt;strong&gt;Ani Nersisyan&lt;/strong&gt;, &lt;strong&gt;Michael Newman&lt;/strong&gt;, &lt;strong&gt;Jiun How Ng&lt;/strong&gt;, &lt;strong&gt;Anthony Nguyen&lt;/strong&gt;, &lt;strong&gt;Murray Nguyen&lt;/strong&gt;, &lt;strong&gt;Rebecca Potter&lt;/strong&gt;, &lt;strong&gt;Charles Rocque&lt;/strong&gt;, &lt;strong&gt;Pedram Roushan&lt;/strong&gt;, &lt;strong&gt;Kannan Sankaragomathi&lt;/strong&gt;, &lt;strong&gt;Christopher Schuster&lt;/strong&gt;, &lt;strong&gt;Michael J. Shearn&lt;/strong&gt;, &lt;strong&gt;Aaron Shorter&lt;/strong&gt;, &lt;strong&gt;Noah Shutty&lt;/strong&gt;, &lt;strong&gt;Vladimir Shvarts&lt;/strong&gt;, &lt;strong&gt;Jindra Skruzny&lt;/strong&gt;, &lt;strong&gt;W. Clarke Smith&lt;/strong&gt;, &lt;strong&gt;George Sterling&lt;/strong&gt;, &lt;strong&gt;Marco Szalay&lt;/strong&gt;, &lt;strong&gt;Douglas Thor&lt;/strong&gt;, &lt;strong&gt;Alfredo Torres&lt;/strong&gt;, &lt;strong&gt;Theodore White&lt;/strong&gt;, &lt;strong&gt;Bryan W. K. Woo&lt;/strong&gt;, &lt;strong&gt;Z. Jamie Yao&lt;/strong&gt;, &lt;strong&gt;Ping Yeh&lt;/strong&gt;, &lt;strong&gt;Juhwan Yoo&lt;/strong&gt;, &lt;strong&gt;Grayson Young&lt;/strong&gt;, &lt;strong&gt;Adam Zalcman&lt;/strong&gt;, &lt;strong&gt;Ningfeng Zhu&lt;/strong&gt;, &lt;strong&gt;Nicholas Zobrist&lt;/strong&gt;, &lt;strong&gt;Hartmut Neven&lt;/strong&gt;, &lt;strong&gt;Vadim Smelyanskiy&lt;/strong&gt;, &lt;strong&gt;Andre Petukhov&lt;/strong&gt;, &lt;strong&gt;Alexander N. 
Korotkov&lt;/strong&gt;, &lt;strong&gt;Daniel Sank&lt;/strong&gt;, &lt;strong&gt;Yu Chen&lt;/strong&gt;
&lt;br /&gt;
    &lt;em&gt;Session N51: Quantum Error Correction Code Performance and Implementation I&lt;/em&gt;
&lt;br /&gt;
    &lt;a href="https://www.nature.com/articles/s41567-023-02226-w"&gt;Link to Paper&lt;/a&gt;
&lt;/p&gt;
&lt;p&gt;

    &lt;a href="https://meetings.aps.org/Meeting/MAR24/Session/N51.11"&gt;Modeling the performance of the surface code with non-uniform error distribution: Part 1&lt;/a&gt;
&lt;br /&gt;
    Presenter: &lt;strong&gt;Yuri D Lensky&lt;/strong&gt;
&lt;br /&gt;
    Authors: &lt;strong&gt;Yuri D Lensky&lt;/strong&gt;, &lt;strong&gt;Volodymyr Sivak&lt;/strong&gt;, &lt;strong&gt;Kostyantyn Kechedzhi&lt;/strong&gt;, &lt;strong&gt;Igor Aleiner&lt;/strong&gt;
&lt;br /&gt;
    &lt;em&gt;Session N51: Quantum Error Correction Code Performance and Implementation I&lt;/em&gt;
&lt;/p&gt;
&lt;p&gt;

    &lt;a href="https://meetings.aps.org/Meeting/MAR24/Session/N51.12"&gt;Modeling the performance of the surface code with non-uniform error distribution: Part 2&lt;/a&gt;
&lt;br /&gt;
    Presenter: &lt;strong&gt;Volodymyr Sivak&lt;/strong&gt;
&lt;br /&gt;
    Authors: &lt;strong&gt;Volodymyr Sivak&lt;/strong&gt;, &lt;strong&gt;Michael Newman&lt;/strong&gt;, &lt;strong&gt;Cody Jones&lt;/strong&gt;, &lt;strong&gt;Henry Schurkus&lt;/strong&gt;, &lt;strong&gt;Dvir Kafri&lt;/strong&gt;, &lt;strong&gt;Yuri D Lensky&lt;/strong&gt;, &lt;strong&gt;Paul Klimov&lt;/strong&gt;, &lt;strong&gt;Kostyantyn Kechedzhi&lt;/strong&gt;, &lt;strong&gt;Vadim Smelyanskiy&lt;/strong&gt;
&lt;br /&gt;
    &lt;em&gt;Session N51: Quantum Error Correction Code Performance and Implementation I&lt;/em&gt;
&lt;/p&gt;
&lt;p&gt;

    &lt;a href="https://meetings.aps.org/Meeting/MAR24/Session/Q51.7"&gt;Highly optimized tensor network contractions for the simulation of classically challenging quantum computations&lt;/a&gt;
&lt;br /&gt;
    Presenter: &lt;strong&gt;Benjamin Villalonga&lt;/strong&gt;
&lt;br /&gt;
    Author: &lt;strong&gt;Benjamin Villalonga&lt;/strong&gt;
&lt;br /&gt;
    &lt;em&gt;Session Q51: Co-evolution of Quantum Classical Algorithms&lt;/em&gt;
&lt;/p&gt;
&lt;p&gt;

    &lt;a href="https://meetings.aps.org/Meeting/MAR24/Session/Q61.7"&gt;Teaching modern quantum computing concepts using hands-on open-source software at all levels&lt;/a&gt;
&lt;br /&gt;
    Presenter: &lt;strong&gt;Abraham Asfaw&lt;/strong&gt;
&lt;br /&gt;
    Author: &lt;strong&gt;Abraham Asfaw&lt;/strong&gt;
&lt;br /&gt;
    &lt;em&gt;Session Q61: Teaching Quantum Information at All Levels II&lt;/em&gt;
&lt;/p&gt;
&lt;/div&gt;

&lt;div style="line-height: 40%;"&gt;
    &lt;br /&gt;
&lt;/div&gt;
  
&lt;h3&gt;Thursday&lt;/h3&gt;
&lt;div style="margin-left: 20px;"&gt;

&lt;p&gt;

    &lt;a href="https://meetings.aps.org/Meeting/MAR24/Session/S51.1"&gt;New circuits and an open source decoder for the color code&lt;/a&gt;
&lt;br /&gt;
    Presenter: &lt;strong&gt;Craig Gidney&lt;/strong&gt;
&lt;br /&gt;
    Authors: &lt;strong&gt;Craig Gidney&lt;/strong&gt;, &lt;strong&gt;Cody Jones&lt;/strong&gt;
&lt;br /&gt;
    &lt;em&gt;Session S51: Quantum Error Correction Code Performance and Implementation II&lt;/em&gt;
&lt;br /&gt;
    &lt;a href="https://arxiv.org/pdf/2312.08813.pdf"&gt;Link to Paper&lt;/a&gt;
&lt;/p&gt;
&lt;p&gt;

    &lt;a href="https://meetings.aps.org/Meeting/MAR24/Session/S18.2"&gt;Performing Hartree-Fock many-body physics calculations with large language models&lt;/a&gt;
&lt;br /&gt;
    Presenter: &lt;strong&gt;Eun-Ah Kim&lt;/strong&gt;
&lt;br /&gt;
    Authors: &lt;strong&gt;Eun-Ah Kim&lt;/strong&gt;, Haining Pan, &lt;strong&gt;Nayantara Mudur&lt;/strong&gt;, William Taranto, &lt;strong&gt;Subhashini Venugopalan&lt;/strong&gt;, &lt;strong&gt;Yasaman Bahri&lt;/strong&gt;, &lt;strong&gt;Michael P Brenner&lt;/strong&gt;
&lt;br /&gt;
    &lt;em&gt;Session S18: Data Science, AI and Machine Learning in Physics I&lt;/em&gt;
&lt;/p&gt;
&lt;p&gt;

    &lt;a href="https://meetings.aps.org/Meeting/MAR24/Session/S51.5"&gt;New methods for reducing resource overhead in the surface code&lt;/a&gt;
&lt;br /&gt;
    Presenter: &lt;strong&gt;Michael Newman&lt;/strong&gt;
&lt;br /&gt;
    Authors: &lt;strong&gt;Craig M Gidney&lt;/strong&gt;, &lt;strong&gt;Michael Newman&lt;/strong&gt;, &lt;strong&gt;Peter Brooks&lt;/strong&gt;, &lt;strong&gt;Cody Jones&lt;/strong&gt;
&lt;br /&gt;
    &lt;em&gt;Session S51: Quantum Error Correction Code Performance and Implementation II&lt;/em&gt;
&lt;br /&gt;
    &lt;a href="https://arxiv.org/pdf/2312.04522.pdf"&gt;Link to Paper&lt;/a&gt;
&lt;/p&gt;
&lt;p&gt;

    &lt;a href="https://meetings.aps.org/Meeting/MAR24/Session/S49.10"&gt;Challenges and opportunities for applying quantum computers to drug design&lt;/a&gt;
&lt;br /&gt;
    Presenter: Raffaele Santagati
&lt;br /&gt;
    Authors: Raffaele Santagati, Alan Aspuru-Guzik, &lt;strong&gt;Ryan Babbush&lt;/strong&gt;, Matthias Degroote, Leticia Gonzalez, Elica Kyoseva, Nikolaj Moll, Markus Oppel, Robert M. Parrish, &lt;strong&gt;Nicholas C. Rubin&lt;/strong&gt;, Michael Streif, Christofer S. Tautermann, Horst Weiss, Nathan Wiebe, Clemens Utschig-Utschig
&lt;br /&gt;
    &lt;em&gt;Session S49: Advances in Quantum Algorithms for Near-Term Applications&lt;/em&gt;
&lt;br /&gt;
    &lt;a href="https://arxiv.org/pdf/2301.04114.pdf"&gt;Link to Paper&lt;/a&gt;
&lt;/p&gt;
&lt;p&gt;

    &lt;a href="https://meetings.aps.org/Meeting/MAR24/Session/T45.1"&gt;Dispatches from Google's hunt for super-quadratic quantum advantage in new applications&lt;/a&gt;
&lt;br /&gt;
    Presenter: &lt;strong&gt;Ryan Babbush&lt;/strong&gt;
&lt;br /&gt;
    Author: &lt;strong&gt;Ryan Babbush&lt;/strong&gt;
&lt;br /&gt;
    &lt;em&gt;Session T45: Recent Advances in Quantum Algorithms&lt;/em&gt;
&lt;/p&gt;
&lt;p&gt;

    &lt;a href="https://meetings.aps.org/Meeting/MAR24/Session/T48.11"&gt;Qubit as a reflectometer&lt;/a&gt;
&lt;br /&gt;
    Presenter: &lt;strong&gt;Yaxing Zhang&lt;/strong&gt;
&lt;br /&gt;
    Authors: &lt;strong&gt;Yaxing Zhang&lt;/strong&gt;, &lt;strong&gt;Benjamin Chiaro&lt;/strong&gt;
&lt;br /&gt;
    &lt;em&gt;Session T48: Superconducting Fabrication, Packaging, &amp;amp; Validation&lt;/em&gt;
&lt;/p&gt;
&lt;p&gt;

    &lt;a href="https://meetings.aps.org/Meeting/MAR24/Session/W14.3"&gt;Random-matrix theory of measurement-induced phase transitions in nonlocal Floquet quantum circuits&lt;/a&gt;
&lt;br /&gt;
    Presenter: Aleksei Khindanov
&lt;br /&gt;
    Authors: Aleksei Khindanov, &lt;strong&gt;Lara Faoro&lt;/strong&gt;, &lt;strong&gt;Lev Ioffe&lt;/strong&gt;, &lt;strong&gt;Igor Aleiner&lt;/strong&gt;
&lt;br /&gt;
    &lt;em&gt;Session W14: Measurement-Induced Phase Transitions&lt;/em&gt;
&lt;/p&gt;
&lt;p&gt;

    &lt;a href="https://meetings.aps.org/Meeting/MAR24/Session/W58.5"&gt;Continuum limit of finite density many-body ground states with MERA&lt;/a&gt;
&lt;br /&gt;
    Presenter: Subhayan Sahu
&lt;br /&gt;
    Authors: Subhayan Sahu, &lt;strong&gt;Guifré Vidal&lt;/strong&gt;
&lt;br /&gt;
    &lt;em&gt;Session W58: Extreme-Scale Computational Science Discovery in Fluid Dynamics and Related Disciplines II&lt;/em&gt;
&lt;/p&gt;
&lt;p&gt;

    &lt;a href="https://meetings.aps.org/Meeting/MAR24/Session/W50.8"&gt;Dynamics of magnetization at infinite temperature in a Heisenberg spin chain&lt;/a&gt;
&lt;br /&gt;
    Presenter: &lt;strong&gt;Eliott Rosenberg&lt;/strong&gt;
&lt;br /&gt;
    Authors: &lt;strong&gt;Eliott Rosenberg&lt;/strong&gt;, &lt;strong&gt;Trond Andersen&lt;/strong&gt;, Rhine Samajdar, &lt;strong&gt;Andre Petukhov&lt;/strong&gt;, Jesse Hoke*,&lt;strong&gt; Dmitry Abanin&lt;/strong&gt;, &lt;strong&gt;Andreas Bengtsson&lt;/strong&gt;, &lt;strong&gt;Ilya Drozdov&lt;/strong&gt;, &lt;strong&gt;Catherine Erickson&lt;/strong&gt;,&lt;strong&gt; Paul Klimov&lt;/strong&gt;, &lt;strong&gt;Xiao Mi&lt;/strong&gt;, &lt;strong&gt;Alexis Morvan&lt;/strong&gt;, &lt;strong&gt;Matthew Neeley&lt;/strong&gt;, &lt;strong&gt;Charles Neill&lt;/strong&gt;, &lt;strong&gt;Rajeev Acharya&lt;/strong&gt;, &lt;strong&gt;Richard Allen&lt;/strong&gt;, &lt;strong&gt;Kyle Anderson&lt;/strong&gt;, &lt;strong&gt;Markus Ansmann&lt;/strong&gt;, &lt;strong&gt;Frank Arute&lt;/strong&gt;, &lt;strong&gt;Kunal Arya&lt;/strong&gt;, &lt;strong&gt;Abraham Asfaw&lt;/strong&gt;, &lt;strong&gt;Juan Atalaya&lt;/strong&gt;, &lt;strong&gt;Joseph Bardin&lt;/strong&gt;, &lt;strong&gt;A. Bilmes&lt;/strong&gt;, &lt;strong&gt;Gina Bortoli&lt;/strong&gt;, &lt;strong&gt;Alexandre Bourassa&lt;/strong&gt;, &lt;strong&gt;Jenna Bovaird&lt;/strong&gt;, &lt;strong&gt;Leon Brill&lt;/strong&gt;, &lt;strong&gt;Michael Broughton&lt;/strong&gt;, &lt;strong&gt;Bob B. 
Buckley&lt;/strong&gt;, &lt;strong&gt;David Buell&lt;/strong&gt;, &lt;strong&gt;Tim Burger&lt;/strong&gt;, &lt;strong&gt;Brian Burkett&lt;/strong&gt;, &lt;strong&gt;Nicholas Bushnell&lt;/strong&gt;, &lt;strong&gt;Juan Campero&lt;/strong&gt;, &lt;strong&gt;Hung-Shen Chang&lt;/strong&gt;, &lt;strong&gt;Zijun Chen&lt;/strong&gt;, &lt;strong&gt;Benjamin Chiaro&lt;/strong&gt;, &lt;strong&gt;Desmond Chik&lt;/strong&gt;, &lt;strong&gt;Josh Cogan&lt;/strong&gt;, &lt;strong&gt;Roberto Collins&lt;/strong&gt;, &lt;strong&gt;Paul Conner&lt;/strong&gt;, &lt;strong&gt;William Courtney&lt;/strong&gt;, &lt;strong&gt;Alexander Crook&lt;/strong&gt;, &lt;strong&gt;Ben Curtin&lt;/strong&gt;, &lt;strong&gt;Dripto Debroy&lt;/strong&gt;, &lt;strong&gt;Alexander Del Toro Barba&lt;/strong&gt;, &lt;strong&gt;Sean Demura&lt;/strong&gt;, &lt;strong&gt;Agustin Di Paolo&lt;/strong&gt;, &lt;strong&gt;Andrew Dunsworth&lt;/strong&gt;, &lt;strong&gt;Clint Earle&lt;/strong&gt;, &lt;strong&gt;E. Farhi&lt;/strong&gt;, &lt;strong&gt;Reza Fatemi&lt;/strong&gt;, &lt;strong&gt;Vinicius Ferreira&lt;/strong&gt;, &lt;strong&gt;Leslie Flores&lt;/strong&gt;, &lt;strong&gt;Ebrahim Forati&lt;/strong&gt;, &lt;strong&gt;Austin Fowler&lt;/strong&gt;, &lt;strong&gt;Brooks Foxen&lt;/strong&gt;, &lt;strong&gt;Gonzalo Garcia&lt;/strong&gt;, &lt;strong&gt;Élie Genois&lt;/strong&gt;, &lt;strong&gt;William Giang&lt;/strong&gt;, &lt;strong&gt;Craig Gidney&lt;/strong&gt;, &lt;strong&gt;Dar Gilboa&lt;/strong&gt;, &lt;strong&gt;Marissa Giustina&lt;/strong&gt;, &lt;strong&gt;Raja Gosula&lt;/strong&gt;, &lt;strong&gt;Alejandro Grajales Dau&lt;/strong&gt;, &lt;strong&gt;Jonathan Gross&lt;/strong&gt;, &lt;strong&gt;Steve Habegger&lt;/strong&gt;, &lt;strong&gt;Michael Hamilton&lt;/strong&gt;, &lt;strong&gt;Monica Hansen&lt;/strong&gt;, &lt;strong&gt;Matthew Harrigan&lt;/strong&gt;, &lt;strong&gt;Sean Harrington&lt;/strong&gt;, &lt;strong&gt;Paula Heu&lt;/strong&gt;, &lt;strong&gt;Gordon Hill&lt;/strong&gt;, &lt;strong&gt;Markus 
Hoffmann&lt;/strong&gt;, &lt;strong&gt;Sabrina Hong&lt;/strong&gt;, &lt;strong&gt;Trent Huang&lt;/strong&gt;, &lt;strong&gt;Ashley Huff&lt;/strong&gt;, &lt;strong&gt;William Huggins&lt;/strong&gt;, &lt;strong&gt;Lev Ioffe&lt;/strong&gt;, &lt;strong&gt;Sergei Isakov&lt;/strong&gt;, &lt;strong&gt;Justin Iveland&lt;/strong&gt;, &lt;strong&gt;Evan Jeffrey&lt;/strong&gt;, &lt;strong&gt;Zhang Jiang&lt;/strong&gt;, &lt;strong&gt;Cody Jones&lt;/strong&gt;, &lt;strong&gt;Pavol Juhas&lt;/strong&gt;, &lt;strong&gt;D. Kafri&lt;/strong&gt;, &lt;strong&gt;Tanuj Khattar&lt;/strong&gt;, &lt;strong&gt;Mostafa Khezri&lt;/strong&gt;, &lt;strong&gt;Mária Kieferová&lt;/strong&gt;, &lt;strong&gt;Seon Kim&lt;/strong&gt;, &lt;strong&gt;Alexei Kitaev&lt;/strong&gt;, &lt;strong&gt;Andrey Klots&lt;/strong&gt;, &lt;strong&gt;Alexander Korotkov&lt;/strong&gt;, &lt;strong&gt;Fedor Kostritsa&lt;/strong&gt;, &lt;strong&gt;John Mark Kreikebaum&lt;/strong&gt;, &lt;strong&gt;David Landhuis&lt;/strong&gt;, &lt;strong&gt;Pavel Laptev&lt;/strong&gt;, &lt;strong&gt;Kim Ming Lau&lt;/strong&gt;, &lt;strong&gt;Lily Laws&lt;/strong&gt;, &lt;strong&gt;Joonho Lee&lt;/strong&gt;, &lt;strong&gt;Kenneth Lee&lt;/strong&gt;, &lt;strong&gt;Yuri Lensky&lt;/strong&gt;, &lt;strong&gt;Brian Lester&lt;/strong&gt;, &lt;strong&gt;Alexander Lill&lt;/strong&gt;, &lt;strong&gt;Wayne Liu&lt;/strong&gt;, &lt;strong&gt;William P. Livingston&lt;/strong&gt;, &lt;strong&gt;A. 
Locharla&lt;/strong&gt;, &lt;strong&gt;Salvatore Mandrà&lt;/strong&gt;, &lt;strong&gt;Orion Martin&lt;/strong&gt;, &lt;strong&gt;Steven Martin&lt;/strong&gt;, &lt;strong&gt;Jarrod McClean&lt;/strong&gt;, &lt;strong&gt;Matthew McEwen&lt;/strong&gt;, &lt;strong&gt;Seneca Meeks&lt;/strong&gt;, &lt;strong&gt;Kevin Miao&lt;/strong&gt;, &lt;strong&gt;Amanda Mieszala&lt;/strong&gt;, &lt;strong&gt;Shirin Montazeri&lt;/strong&gt;, &lt;strong&gt;Ramis Movassagh&lt;/strong&gt;, &lt;strong&gt;Wojciech Mruczkiewicz&lt;/strong&gt;, &lt;strong&gt;Ani Nersisyan&lt;/strong&gt;, &lt;strong&gt;Michael Newman&lt;/strong&gt;, &lt;strong&gt;Jiun How Ng&lt;/strong&gt;, &lt;strong&gt;Anthony Nguyen&lt;/strong&gt;, &lt;strong&gt;Murray Nguyen&lt;/strong&gt;, &lt;strong&gt;M. Niu&lt;/strong&gt;, &lt;strong&gt;Thomas O'Brien&lt;/strong&gt;, &lt;strong&gt;Seun Omonije&lt;/strong&gt;, &lt;strong&gt;Alex Opremcak&lt;/strong&gt;, &lt;strong&gt;Rebecca Potter&lt;/strong&gt;, &lt;strong&gt;Leonid Pryadko&lt;/strong&gt;, &lt;strong&gt;Chris Quintana&lt;/strong&gt;, &lt;strong&gt;David Rhodes&lt;/strong&gt;, &lt;strong&gt;Charles Rocque&lt;/strong&gt;, &lt;strong&gt;N. 
Rubin&lt;/strong&gt;, &lt;strong&gt;Negar Saei&lt;/strong&gt;, &lt;strong&gt;Daniel Sank&lt;/strong&gt;, &lt;strong&gt;Kannan Sankaragomathi&lt;/strong&gt;, &lt;strong&gt;Kevin Satzinger&lt;/strong&gt;, &lt;strong&gt;Henry Schurkus&lt;/strong&gt;, &lt;strong&gt;Christopher Schuster&lt;/strong&gt;, &lt;strong&gt;Michael Shearn&lt;/strong&gt;, &lt;strong&gt;Aaron Shorter&lt;/strong&gt;, &lt;strong&gt;Noah Shutty&lt;/strong&gt;, &lt;strong&gt;Vladimir Shvarts&lt;/strong&gt;, &lt;strong&gt;Volodymyr Sivak&lt;/strong&gt;, &lt;strong&gt;Jindra Skruzny&lt;/strong&gt;, &lt;strong&gt;Clarke Smith&lt;/strong&gt;, &lt;strong&gt;Rolando Somma&lt;/strong&gt;, &lt;strong&gt;George Sterling&lt;/strong&gt;, &lt;strong&gt;Doug Strain&lt;/strong&gt;, &lt;strong&gt;Marco Szalay&lt;/strong&gt;, &lt;strong&gt;Douglas Thor&lt;/strong&gt;, &lt;strong&gt;Alfredo Torres&lt;/strong&gt;, &lt;strong&gt;Guifre Vidal&lt;/strong&gt;, &lt;strong&gt;Benjamin Villalonga&lt;/strong&gt;, &lt;strong&gt;Catherine Vollgraff Heidweiller&lt;/strong&gt;, &lt;strong&gt;Theodore White&lt;/strong&gt;, &lt;strong&gt;Bryan Woo&lt;/strong&gt;, &lt;strong&gt;Cheng Xing&lt;/strong&gt;, &lt;strong&gt;Jamie Yao&lt;/strong&gt;, &lt;strong&gt;Ping Yeh&lt;/strong&gt;, &lt;strong&gt;Juhwan Yoo&lt;/strong&gt;, &lt;strong&gt;Grayson Young&lt;/strong&gt;, &lt;strong&gt;Adam Zalcman&lt;/strong&gt;, &lt;strong&gt;Yaxing Zhang&lt;/strong&gt;, &lt;strong&gt;Ningfeng Zhu&lt;/strong&gt;, &lt;strong&gt;Nicholas Zobrist&lt;/strong&gt;, &lt;strong&gt;Hartmut Neven&lt;/strong&gt;, &lt;strong&gt;Ryan Babbush&lt;/strong&gt;, &lt;strong&gt;Dave Bacon&lt;/strong&gt;, &lt;strong&gt;Sergio Boixo&lt;/strong&gt;, &lt;strong&gt;Jeremy Hilton&lt;/strong&gt;, &lt;strong&gt;Erik Lucero&lt;/strong&gt;, &lt;strong&gt;Anthony Megrant&lt;/strong&gt;, &lt;strong&gt;Julian Kelly&lt;/strong&gt;, &lt;strong&gt;Yu Chen&lt;/strong&gt;, &lt;strong&gt;Vadim Smelyanskiy&lt;/strong&gt;, Vedika Khemani, Sarang Gopalakrishnan,&lt;strong&gt; Tomaž 
Prosen&lt;/strong&gt;, &lt;strong&gt;Pedram Roushan&lt;/strong&gt;
&lt;br /&gt;
    &lt;em&gt;Session W50: Quantum Simulation of Many-Body Physics&lt;/em&gt;
&lt;br /&gt;
    &lt;a href="https://arxiv.org/pdf/2306.09333.pdf"&gt;Link to Paper&lt;/a&gt;
&lt;/p&gt;
&lt;p&gt;

    &lt;a href="https://meetings.aps.org/Meeting/MAR24/Session/W50.13"&gt;The fast multipole method on a quantum computer&lt;/a&gt;
&lt;br /&gt;
    Presenter: Kianna Wan
&lt;br /&gt;
    Authors: Kianna Wan, Dominic W Berry, &lt;strong&gt;Ryan Babbush&lt;/strong&gt;
&lt;br /&gt;
    &lt;em&gt;Session W50: Quantum Simulation of Many-Body Physics&lt;/em&gt;
&lt;/p&gt;
&lt;/div&gt;

&lt;div style="line-height: 40%;"&gt;
    &lt;br /&gt;
&lt;/div&gt;

&lt;h3&gt;Friday&lt;/h3&gt;
&lt;div style="margin-left: 20px;"&gt;

&lt;p&gt;

    &lt;a href="https://meetings.aps.org/Meeting/MAR24/Session/Y43.1"&gt;The quantum computing industry and protecting national security: what tools will work?&lt;/a&gt;
&lt;br /&gt;
    Presenter: &lt;strong&gt;Kate Weber&lt;/strong&gt;
&lt;br /&gt;
    Author: &lt;strong&gt;Kate Weber&lt;/strong&gt;
  &lt;br /&gt;
  &lt;em&gt;Session Y43: Industry, Innovation, and National Security: Finding the Right Balance&lt;/em&gt;
&lt;/p&gt;
&lt;p&gt;

    &lt;a href="https://meetings.aps.org/Meeting/MAR24/Session/Y46.3"&gt;Novel charging effects in the fluxonium qubit&lt;/a&gt;
&lt;br /&gt;
    Presenter: &lt;strong&gt;Agustin Di Paolo&lt;/strong&gt;
&lt;br /&gt;
    Authors: &lt;strong&gt;Agustin Di Paolo&lt;/strong&gt;, Kyle Serniak, Andrew J Kerman, &lt;strong&gt;William D Oliver&lt;/strong&gt;
&lt;br /&gt;
    &lt;em&gt;Session Y46: Fluxonium-Based Superconducting Qubits&lt;/em&gt;
&lt;/p&gt;
&lt;p&gt;

    &lt;a href="https://meetings.aps.org/Meeting/MAR24/Session/Z46.3"&gt;Microwave Engineering of Parametric Interactions in Superconducting Circuits&lt;/a&gt;
&lt;br /&gt;
    Presenter: &lt;strong&gt;Ofer Naaman&lt;/strong&gt;
&lt;br /&gt;
    Author: &lt;strong&gt;Ofer Naaman&lt;/strong&gt;
&lt;br /&gt;
    &lt;em&gt;Session Z46: Broadband Parametric Amplifiers and Circulators&lt;/em&gt;
&lt;/p&gt;
&lt;p&gt;

    &lt;a href="https://meetings.aps.org/Meeting/MAR24/Session/Z62.3"&gt;Linear spin wave theory of large magnetic unit cells using the Kernel Polynomial Method&lt;/a&gt;
&lt;br /&gt;
    Presenter: Harry Lane
&lt;br /&gt;
    Authors: Harry Lane, Hao Zhang, David A Dahlbom, Sam Quinn, &lt;strong&gt;Rolando D Somma&lt;/strong&gt;, Martin P Mourigal, Cristian D Batista, Kipton Barros
&lt;br /&gt;
    &lt;em&gt;Session Z62: Cooperative Phenomena, Theory&lt;/em&gt;
&lt;/p&gt;
&lt;/div&gt;

&lt;!--Footnotes--&gt;
&lt;hr width="80%" /&gt;
&lt;p&gt;
  &lt;span class="Apple-style-span" style="font-size: x-small;"&gt;&lt;sup&gt;&lt;b&gt;*&lt;/b&gt;&lt;/sup&gt;Work done while at Google&lt;/span&gt;&lt;/p&gt;
</content><link href="http://blog.research.google/feeds/2754526782497247497/comments/default" rel="replies" title="Post Comments" type="application/atom+xml"/><link href="http://blog.research.google/2024/03/google-at-aps-2024.html#comment-form" rel="replies" title="0 Comments" type="text/html"/><link href="http://www.blogger.com/feeds/8474926331452026626/posts/default/2754526782497247497" rel="edit" type="application/atom+xml"/><link href="http://www.blogger.com/feeds/8474926331452026626/posts/default/2754526782497247497" rel="self" type="application/atom+xml"/><link href="http://blog.research.google/2024/03/google-at-aps-2024.html" rel="alternate" title="Google at APS 2024" type="text/html"/><author><name>Google AI</name><uri>http://www.blogger.com/profile/12098626514775266161</uri><email>noreply@blogger.com</email><gd:image height="16" rel="http://schemas.google.com/g/2005#thumbnail" src="https://img1.blogblog.com/img/b16-rounded.gif" width="16"/></author><media:thumbnail xmlns:media="http://search.yahoo.com/mrss/" height="72" url="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjy22Hfq3RN4qRUJcSMUpIau4ueOIcQ219mDvfu4FNJ9kf5PBMUI0x4Uf9BhoIHtnFUhtvE72GCVYixldOZRSeePJfef0P87Pc_djQeGIZOhyxv9nKsQCc57357tr3npWdS5fyWxiGjex4NxMpOIB2JE1Z2qXdLnzLkFM075WstFJD77xVNS2T9hckWZyLf/s72-c/lockup_GoogleResearch_FullColor_Hero.jpg" width="72"/><thr:total>0</thr:total></entry><entry><id>tag:blogger.com,1999:blog-8474926331452026626.post-1695264277638670894</id><published>2024-02-22T12:05:00.000-08:00</published><updated>2024-02-23T10:07:08.500-08:00</updated><category scheme="http://www.blogger.com/atom/ns#" term="Machine Intelligence"/><category scheme="http://www.blogger.com/atom/ns#" term="Machine Perception"/><title type="text">VideoPrism: A foundational visual encoder for video understanding</title><content type="html">&lt;span class="byline-author"&gt;Posted by Long Zhao, Senior Research Scientist, and Ting Liu, Senior Staff Software Engineer, Google 
Research&lt;/span&gt;

&lt;img src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEi4kKy9Vqp7LE__mAG3METzRxmp6Z5PCH8AyfXzxQ_mNeIgOwYitblprQbb1fOTSUDgNgdmgsm7QwyXgkBcUDs2iIkxGue1n1sxdaomCyAo_eZD1-NFJEbn0fct-gJSNNs_MXHQQCxA79hVbd2CHzg2Nkpw1RnsOQWLq4Y7A7mxXTAFjR9NEE42A6pMOaDi/s450/VideoPrismSample.gif" style="display: none;" /&gt;

&lt;p&gt;
An astounding number of videos are available on the Web, covering a variety of content from everyday moments people share to historical moments to scientific observations, each of which contains a unique record of the world. The right tools could help researchers analyze these videos, transforming how we understand the world around us.
&lt;/p&gt;
&lt;a name='more'&gt;&lt;/a&gt; 
&lt;p&gt;
Videos offer dynamic visual content far richer than static images, capturing movement, changes, and dynamic relationships between entities. Analyzing this complexity, along with the immense diversity of publicly available video data, demands models that go beyond traditional image understanding. Consequently, many of the best-performing approaches to video understanding still rely on specialized models tailor-made for particular tasks. Recently, there has been exciting progress in this area using video foundation models (ViFMs), such as &lt;a href="https://arxiv.org/abs/2109.14084"&gt;VideoCLIP&lt;/a&gt;, &lt;a href="https://arxiv.org/abs/2212.03191"&gt;InternVideo&lt;/a&gt;, &lt;a href="https://arxiv.org/abs/2212.04979"&gt;VideoCoCa&lt;/a&gt;, and &lt;a href="https://arxiv.org/abs/2303.16058"&gt;UMT&lt;/a&gt;. However, building a ViFM that handles the sheer diversity of video data remains a challenge.
&lt;/p&gt;
&lt;p&gt;
With the goal of building a single model for general-purpose video understanding, we introduce “&lt;a href="https://arxiv.org/abs/2402.13217"&gt;VideoPrism: A Foundational Visual Encoder for Video Understanding&lt;/a&gt;”. VideoPrism is a ViFM designed to handle a wide spectrum of video understanding tasks, including classification, localization, retrieval, captioning, and question answering (QA). We propose innovations in both the pre-training data as well as the modeling strategy. We pre-train VideoPrism on a massive and diverse dataset: 36 million high-quality video-text pairs and 582 million video clips with noisy or machine-generated parallel text. Our pre-training approach is designed for this hybrid data, to learn both from video-text pairs and the videos themselves. VideoPrism is incredibly easy to adapt to new video understanding challenges, and achieves state-of-the-art performance using a single frozen model.
&lt;/p&gt;&lt;p&gt;&lt;/p&gt;

&lt;video autoplay="" loop="" muted="" playsinline="" width="100%"&gt; &lt;source src="https://github.com/garyzhao/videoprism-blog/raw/main/teaser.mp4" type="video/mp4"&gt;&lt;/source&gt; &lt;/video&gt;
&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;VideoPrism is a general-purpose video encoder that enables state-of-the-art results over a wide spectrum of video understanding tasks, including classification, localization, retrieval, captioning, and question answering, by producing video representations from a single frozen model.&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;

&lt;br /&gt; 

&lt;h2&gt;Pre-training data&lt;/h2&gt;


&lt;p&gt;
A powerful ViFM needs a very large collection of videos on which to train — similar to other foundation models (FMs), such as large language models (LLMs). Ideally, we would want the pre-training data to be a representative sample of all the videos in the world. While naturally most of these videos do not have perfect captions or descriptions, even imperfect text can provide useful information about the semantic content of the video.
&lt;/p&gt;
&lt;p&gt;
To give our model the best possible starting point, we put together a massive pre-training corpus consisting of several public and private datasets, including &lt;a href="https://rowanzellers.com/merlot/"&gt;YT-Temporal-180M&lt;/a&gt;, &lt;a href="https://arxiv.org/abs/2307.06942"&gt;InternVid&lt;/a&gt;, &lt;a href="https://arxiv.org/abs/2204.00679"&gt;VideoCC&lt;/a&gt;, &lt;a href="https://arxiv.org/abs/2007.14937"&gt;WTS-70M&lt;/a&gt;, etc. This includes 36 million carefully selected videos with high-quality captions, along with an additional 582 million clips with varying levels of noisy text (like auto-generated transcripts). To our knowledge, this is the largest and most diverse video training corpus of its kind.
&lt;/p&gt;

&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgrhfnM1Rg_xbS1b3ZtydWc0M7zOchLpi5qdj65UaR3mOYbV8SQQqKhUhltYwmkPNqrULdeVeE1nU3gnRkjR7pE-yFaiVRC1al-BxZecsO0aojXFzSDhfv45oZoOBeYA93IiNeCGdnUryh4HLc3w7Qr2PX0fy6-4qFMTKBORA_PfHspp7Nr1OW0WnAvn-S9/s1999/image18.png" imageanchor="1" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="779" data-original-width="1999" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgrhfnM1Rg_xbS1b3ZtydWc0M7zOchLpi5qdj65UaR3mOYbV8SQQqKhUhltYwmkPNqrULdeVeE1nU3gnRkjR7pE-yFaiVRC1al-BxZecsO0aojXFzSDhfv45oZoOBeYA93IiNeCGdnUryh4HLc3w7Qr2PX0fy6-4qFMTKBORA_PfHspp7Nr1OW0WnAvn-S9/s16000/image18.png" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;Statistics on the video-text pre-training data. The large variations of the&amp;nbsp;&lt;a href="https://arxiv.org/abs/2104.14806"&gt;CLIP similarity scores&lt;/a&gt;&amp;nbsp;(the higher, the better) demonstrate the diverse caption quality of our pre-training data, which is a byproduct of the various ways used to harvest the text.&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;

&lt;br /&gt; 

&lt;h2&gt;Two-stage training&lt;/h2&gt;


&lt;p&gt;
The VideoPrism model architecture stems from the standard &lt;a href="https://arxiv.org/abs/2010.11929"&gt;vision transformer&lt;/a&gt; (ViT) with a factorized design that sequentially encodes spatial and temporal information following &lt;a href="https://arxiv.org/abs/2103.15691"&gt;ViViT&lt;/a&gt;. Our training approach leverages both the high-quality video-text data and the video data with noisy text mentioned above. To start, we use &lt;a href="https://en.wikipedia.org/wiki/Self-supervised_learning#Contrastive_self-supervised_learning"&gt;contrastive learning&lt;/a&gt; (an approach that minimizes the distance between positive video-text pairs while maximizing the distance between negative video-text pairs) to teach our model to match videos with their own text descriptions, including imperfect ones. This builds a foundation for matching semantic language content to visual content.
&lt;/p&gt;
&lt;p&gt;
After video-text contrastive training, we leverage the collection of videos without text descriptions. Here, we build on the &lt;a href="https://arxiv.org/abs/2212.04500"&gt;masked video modeling framework&lt;/a&gt; to predict masked patches in a video, with a few improvements. We train the model to predict both the video-level global embedding and token-wise embeddings from the first-stage model to effectively leverage the knowledge acquired in that stage. We then randomly shuffle the predicted tokens to prevent the model from learning shortcuts.
&lt;/p&gt;
&lt;p&gt;
What is unique about VideoPrism’s setup is that we use two complementary pre-training signals: text descriptions and the visual content within a video. Text descriptions often focus on what things look like, while the video content provides information about movement and visual dynamics. This enables VideoPrism to excel in tasks that demand an understanding of both appearance and motion.
&lt;/p&gt;
&lt;br /&gt; 

&lt;h2&gt;Results&lt;/h2&gt;


&lt;p&gt;
We conduct extensive evaluation on VideoPrism across four broad categories of video understanding tasks: video classification and localization, video-text retrieval, video captioning and question answering, and scientific video understanding. VideoPrism achieves state-of-the-art performance on 30 out of 33 video understanding benchmarks — all with minimal adaptation of a single, frozen model.
&lt;/p&gt;

&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgiUtXCxgEXrgAZJ2B-Mn8L0DP7VkFUfUbI1yLTgGYSbWtn_Q5AjgGRgi3yQ5PMB3fVFlHLzDP4yhlCeGaPpdXr5I1-TNYelYMUBYiXx16qNYTpqKwAqXX7-EFV-4Asn6qYFWOb6_5p71n5Zzxbt-ZeUy5yIj2aieKXl0LnFOqdhKXa56xm4ZoXbccYDz3H/s1999/image20.png" imageanchor="1" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="1999" data-original-width="1959" height="640" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgiUtXCxgEXrgAZJ2B-Mn8L0DP7VkFUfUbI1yLTgGYSbWtn_Q5AjgGRgi3yQ5PMB3fVFlHLzDP4yhlCeGaPpdXr5I1-TNYelYMUBYiXx16qNYTpqKwAqXX7-EFV-4Asn6qYFWOb6_5p71n5Zzxbt-ZeUy5yIj2aieKXl0LnFOqdhKXa56xm4ZoXbccYDz3H/w628-h640/image20.png" width="628" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;VideoPrism compared to the previous best-performing FMs.&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;

&lt;div style="line-height: 40%;"&gt;&lt;br /&gt;
&lt;/div&gt; 

&lt;h3&gt;Classification and localization&lt;/h3&gt;


&lt;p&gt;
We evaluate VideoPrism on an existing large-scale video understanding benchmark (&lt;a href="https://arxiv.org/abs/2307.03166"&gt;VideoGLUE&lt;/a&gt;) covering classification and localization tasks. We find that (1) VideoPrism outperforms all of the other state-of-the-art FMs, and (2) no other single model consistently comes in second place. This tells us that VideoPrism has learned to effectively pack a variety of video signals into one encoder — from semantics at different granularities to appearance and motion cues — and it works well across a variety of video sources.
&lt;/p&gt;

&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhNnyg_lnLfwDIsJElqFwLKJleb1quzOR4h7X5jBf_bAnxwo_Em-_XLtWkkyMkyMPcLGdm0F25tLmccw3eK9qt6NN4LrLvfF45Wu8J2ylCqi4hPE-rFOwzmGuV8II6Nq8hileMNrS1lMwCuOHTVNGS04Dsxc7yVztaMCu0sRvuMUHnN4u9IKEvv2g8fRYWo/s1816/image12.png" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="742" data-original-width="1816" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhNnyg_lnLfwDIsJElqFwLKJleb1quzOR4h7X5jBf_bAnxwo_Em-_XLtWkkyMkyMPcLGdm0F25tLmccw3eK9qt6NN4LrLvfF45Wu8J2ylCqi4hPE-rFOwzmGuV8II6Nq8hileMNrS1lMwCuOHTVNGS04Dsxc7yVztaMCu0sRvuMUHnN4u9IKEvv2g8fRYWo/s16000/image12.png" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;VideoPrism outperforms state-of-the-art approaches (including &lt;a href="https://arxiv.org/abs/2103.00020"&gt;CLIP&lt;/a&gt;, &lt;a href="https://arxiv.org/abs/2104.11178"&gt;VATT&lt;/a&gt;, &lt;a href="https://arxiv.org/abs/2212.03191"&gt;InternVideo&lt;/a&gt;, and &lt;a href="https://arxiv.org/abs/2303.16058"&gt;UMT&lt;/a&gt;) on the &lt;a href="https://arxiv.org/abs/2307.03166"&gt;video understanding benchmark&lt;/a&gt;. In this plot, we show the absolute score differences compared with the previous best model to highlight the relative improvements of VideoPrism. 
On &lt;a href="http://vuchallenge.org/charades.html"&gt;Charades&lt;/a&gt;, &lt;a href="http://activity-net.org/"&gt;ActivityNet&lt;/a&gt;, &lt;a href="https://research.google.com/ava/"&gt;AVA&lt;/a&gt;, and &lt;a href="https://research.google.com/ava/"&gt;AVA-K&lt;/a&gt;, we use &lt;a href="https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Mean_average_precision"&gt;mean average precision&lt;/a&gt; (mAP) as the evaluation metric. On the other datasets, we report top-1 accuracy.&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;
&lt;div style="line-height: 40%;"&gt;
    &lt;br /&gt;
&lt;/div&gt; 

&lt;h3&gt;Combining with LLMs&lt;/h3&gt;


&lt;p&gt;
We further explore combining VideoPrism with LLMs to unlock its ability to handle various video-language tasks. In particular, when paired with a text encoder (following &lt;a href="https://arxiv.org/abs/2111.07991"&gt;LiT&lt;/a&gt;) or a language decoder (such as &lt;a href="https://arxiv.org/abs/2305.10403"&gt;PaLM-2&lt;/a&gt;), VideoPrism can be utilized for video-text retrieval, video captioning, and video QA tasks. We compare the combined models on a broad and challenging set of vision-language benchmarks. VideoPrism sets the new state of the art on most benchmarks. From the visual results, we find that VideoPrism is capable of understanding complex motions and appearances in videos (e.g., the model can recognize the different colors of spinning objects on the window in the visual examples below). These results demonstrate that VideoPrism is strongly compatible with language models.
&lt;/p&gt;

&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjd7V86xYM18_i3s0aemjiiYxaJeBiooZrEicQ5VVkLK3QnWTR96hKVsobSO4qRiN0f253JPX4y-T_h17E2Rx80PIVtVed0q499uCv42RzxZ7crkr21nuCR0zwalkSUX9FxIbjWVmlQGb1yx9Y5J8aVT_ROkY4DB1skUkk-bc9FaCc6tc-XLumHk5P65_UR/s1028/VideoPrismResults.png" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="932" data-original-width="1028" height="580" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjd7V86xYM18_i3s0aemjiiYxaJeBiooZrEicQ5VVkLK3QnWTR96hKVsobSO4qRiN0f253JPX4y-T_h17E2Rx80PIVtVed0q499uCv42RzxZ7crkr21nuCR0zwalkSUX9FxIbjWVmlQGb1yx9Y5J8aVT_ROkY4DB1skUkk-bc9FaCc6tc-XLumHk5P65_UR/w640-h580/VideoPrismResults.png" width="640" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;VideoPrism achieves competitive results compared with state-of-the-art approaches (including &lt;a href="https://arxiv.org/abs/2212.04979"&gt;VideoCoCa&lt;/a&gt;, &lt;a href="https://arxiv.org/abs/2303.16058"&gt;UMT&lt;/a&gt; and &lt;a href="https://arxiv.org/abs/2204.14198"&gt;Flamingo&lt;/a&gt;) on multiple video-text retrieval (top) and video captioning and video QA (bottom) benchmarks. We also show the absolute score differences compared with the previous best model to highlight the relative improvements of VideoPrism. 
We report Recall@1 on &lt;a href="https://www.microsoft.com/en-us/research/publication/msr-vtt-a-large-video-description-dataset-for-bridging-video-and-language/"&gt;MSRVTT&lt;/a&gt;, &lt;a href="https://eric-xw.github.io/vatex-website/index.html"&gt;VATEX&lt;/a&gt;, and &lt;a href="https://cs.stanford.edu/people/ranjaykrishna/densevid/"&gt;ActivityNet&lt;/a&gt;; &lt;a href="https://arxiv.org/abs/1411.5726"&gt;CIDEr score&lt;/a&gt; on &lt;a href="https://www.microsoft.com/en-us/research/publication/msr-vtt-a-large-video-description-dataset-for-bridging-video-and-language/"&gt;MSRVTT-Cap&lt;/a&gt;, &lt;a href="https://eric-xw.github.io/vatex-website/index.html"&gt;VATEX-Cap&lt;/a&gt;, and &lt;a href="http://youcook2.eecs.umich.edu/"&gt;YouCook2&lt;/a&gt;; top-1 accuracy on &lt;a href="https://github.com/xudejing/video-question-answering"&gt;MSRVTT-QA&lt;/a&gt; and &lt;a href="https://github.com/xudejing/video-question-answering"&gt;MSVD-QA&lt;/a&gt;; and &lt;a href="https://arxiv.org/abs/cmp-lg/9406033"&gt;WUPS index&lt;/a&gt; on &lt;a href="https://doc-doc.github.io/docs/nextqa.html"&gt;NExT-QA&lt;/a&gt;.&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;
&lt;br /&gt;

&lt;video autoplay="" loop="" muted="" playsinline="" width="100%"&gt; &lt;source src="https://github.com/garyzhao/videoprism-blog/raw/main/snowball_water_bottle_drum.mp4" type="video/mp4"&gt;&lt;/source&gt; &lt;/video&gt;
&lt;video autoplay="" loop="" muted="" playsinline="" width="100%"&gt; &lt;source src="https://github.com/garyzhao/videoprism-blog/raw/main/spin_roller_skating.mp4" type="video/mp4"&gt;&lt;/source&gt; &lt;/video&gt;
&lt;video autoplay="" loop="" muted="" playsinline="" width="100%"&gt; &lt;source src="https://github.com/garyzhao/videoprism-blog/raw/main/making_ice_cream_ski_lifting.mp4" type="video/mp4"&gt;&lt;/source&gt; &lt;/video&gt;
&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;We show qualitative results using VideoPrism with a text encoder for video-text retrieval (first row) and adapted to a language decoder for video QA (second and third row). For video-text retrieval examples, the blue bars indicate the embedding similarities between the videos and the text queries.&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;

&lt;div style="line-height: 40%;"&gt;
    &lt;br /&gt;
&lt;/div&gt; 

&lt;h3&gt;Scientific applications&lt;/h3&gt;


&lt;p&gt;
Finally, we test VideoPrism on datasets used by scientists across domains, including fields such as ethology, behavioral neuroscience, and ecology. These datasets typically require domain expertise to annotate, for which we leverage existing scientific datasets open-sourced by the community including &lt;a href="https://data.caltech.edu/records/zrznw-w7386"&gt;Fly vs. Fly&lt;/a&gt;, &lt;a href="https://data.caltech.edu/records/s0vdx-0k302"&gt;CalMS21&lt;/a&gt;, &lt;a href="https://shirleymaxx.github.io/ChimpACT/"&gt;ChimpACT&lt;/a&gt;, and &lt;a href="https://dirtmaxim.github.io/kabr/"&gt;KABR&lt;/a&gt;. VideoPrism not only performs exceptionally well, but actually surpasses models designed specifically for those tasks. This suggests tools like VideoPrism have the potential to transform how scientists analyze video data across different fields.
&lt;/p&gt;

&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEi3v-C36GWUp8CkaCVqFvaXYKW6-1SvCo99Ogiul-fSTkftyc-t4z5CNUgEWlJkRmzranQrYHldtBvjeJXsqdB4ZbgBkyaZv-_I9QE5U7kus_Z8QWlVqfzX0JfELSDPfGj9V4QqhUMwX_EkyPM-vG7pdYMXN0kj1-s98IZJl3U8CpvqoOHyAsuwXIVt7M4_/s1200/image5.png" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="742" data-original-width="1200" height="397" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEi3v-C36GWUp8CkaCVqFvaXYKW6-1SvCo99Ogiul-fSTkftyc-t4z5CNUgEWlJkRmzranQrYHldtBvjeJXsqdB4ZbgBkyaZv-_I9QE5U7kus_Z8QWlVqfzX0JfELSDPfGj9V4QqhUMwX_EkyPM-vG7pdYMXN0kj1-s98IZJl3U8CpvqoOHyAsuwXIVt7M4_/w640-h397/image5.png" width="640" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;VideoPrism outperforms the domain experts on various scientific benchmarks. We show the absolute score differences to highlight the relative improvements of VideoPrism. We report mean average precision (mAP) for all datasets, except for KABR which uses class-averaged top-1 accuracy.&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;
&lt;br /&gt; 

&lt;h2&gt;Conclusion&lt;/h2&gt;


&lt;p&gt;
With VideoPrism, we introduce a powerful and versatile video encoder that sets a new standard for general-purpose video understanding. Our emphasis on both building a massive and varied pre-training dataset and innovative modeling techniques has been validated through our extensive evaluations. Not only does VideoPrism consistently outperform strong baselines, but its unique ability to generalize positions it well for tackling an array of real-world applications. Because of its potential broad use, we are committed to continuing further responsible research in this space, guided by our &lt;a href="http://ai.google/principles"&gt;AI Principles&lt;/a&gt;. We hope VideoPrism paves the way for future breakthroughs at the intersection of AI and video analysis, helping to realize the potential of ViFMs across domains such as scientific discovery, education, and healthcare.
&lt;/p&gt;
&lt;br /&gt; 

&lt;h2&gt;Acknowledgements&lt;/h2&gt;


&lt;p&gt;
&lt;em&gt;This blog post is made on behalf of all the VideoPrism authors: Long Zhao, Nitesh B. Gundavarapu, Liangzhe Yuan, Hao Zhou, Shen Yan, Jennifer J. Sun, Luke Friedman, Rui Qian, Tobias Weyand, Yue Zhao, Rachel Hornung, Florian Schroff, Ming-Hsuan Yang, David A. Ross, Huisheng Wang, Hartwig Adam, Mikhail Sirotenko, Ting Liu, and Boqing Gong. We sincerely thank David Hendon for their product management efforts, and Alex Siegman, Ramya Ganeshan, and Victor Gomes for their program and resource management efforts. We also thank Hassan Akbari, Sherry Ben, Yoni Ben-Meshulam, Chun-Te Chu, Sam Clearwater, Yin Cui, Ilya Figotin, Anja Hauth, Sergey Ioffe, Xuhui Jia, Yeqing Li, Lu Jiang, Zu Kim, Dan Kondratyuk, Bill Mark, Arsha Nagrani, Caroline Pantofaru, Sushant Prakash, Cordelia Schmid, Bryan Seybold, Mojtaba Seyedhosseini, Amanda Sadler, Rif A. Saurous, Rachel Stigler, Paul Voigtlaender, Pingmei Xu, Chaochao Yan, Xuan Yang, and Yukun Zhu for the discussions, support, and feedback that greatly contributed to this work. We are grateful to Jay Yagnik, Rahul Sukthankar, and Tomas Izo for their enthusiastic support for this project. Lastly, we thank Tom Small, Jennifer J. Sun, Hao Zhou, Nitesh B. Gundavarapu, Luke Friedman, and Mikhail Sirotenko for the tremendous help with making this blog post.&lt;/em&gt;
&lt;/p&gt;&lt;p&gt;&lt;/p&gt;</content><link href="http://blog.research.google/feeds/1695264277638670894/comments/default" rel="replies" title="Post Comments" type="application/atom+xml"/><link href="http://blog.research.google/2024/02/videoprism-foundational-visual-encoder.html#comment-form" rel="replies" title="0 Comments" type="text/html"/><link href="http://www.blogger.com/feeds/8474926331452026626/posts/default/1695264277638670894" rel="edit" type="application/atom+xml"/><link href="http://www.blogger.com/feeds/8474926331452026626/posts/default/1695264277638670894" rel="self" type="application/atom+xml"/><link href="http://blog.research.google/2024/02/videoprism-foundational-visual-encoder.html" rel="alternate" title="VideoPrism: A foundational visual encoder for video understanding" type="text/html"/><author><name>Google AI</name><uri>http://www.blogger.com/profile/12098626514775266161</uri><email>noreply@blogger.com</email><gd:image height="16" rel="http://schemas.google.com/g/2005#thumbnail" src="https://img1.blogblog.com/img/b16-rounded.gif" width="16"/></author><media:thumbnail xmlns:media="http://search.yahoo.com/mrss/" height="72" url="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEi4kKy9Vqp7LE__mAG3METzRxmp6Z5PCH8AyfXzxQ_mNeIgOwYitblprQbb1fOTSUDgNgdmgsm7QwyXgkBcUDs2iIkxGue1n1sxdaomCyAo_eZD1-NFJEbn0fct-gJSNNs_MXHQQCxA79hVbd2CHzg2Nkpw1RnsOQWLq4Y7A7mxXTAFjR9NEE42A6pMOaDi/s72-c/VideoPrismSample.gif" width="72"/><thr:total>0</thr:total></entry><entry><id>tag:blogger.com,1999:blog-8474926331452026626.post-4343235509909091741</id><published>2024-02-21T12:15:00.000-08:00</published><updated>2024-02-21T12:15:36.694-08:00</updated><category scheme="http://www.blogger.com/atom/ns#" term="Differential Privacy"/><category scheme="http://www.blogger.com/atom/ns#" term="Gboard"/><category scheme="http://www.blogger.com/atom/ns#" term="On-device Learning"/><category scheme="http://www.blogger.com/atom/ns#" term="Responsible AI"/><title type="text">Advances 
in private training for production on-device language models</title><content type="html">&lt;span class="byline-author"&gt;Posted by Zheng Xu, Research Scientist, and Yanxiang Zhang, Software Engineer, Google&lt;/span&gt;

&lt;img src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEifnCZ_XGUoUG0hESM0dF5B8Rsoqo4YrT_-uv0hlDM1iTADhtEEyEvBM4hOWT0rxgpVtZKyuFoj2xeXmkeXwGe-XTmvBuwBDJOCqgN8Ba7Wcjh_s1seWUaCRl1xNpNe_6MqxcFFZoAvhfCge5vq9UATjXG_BnTiGdQ6YLLo7AK7ABS3KLFMKmjAtA1gkcBk/s1600/GBoard%20PrivacyHero.gif" style="display: none;" /&gt;

&lt;p&gt;
Language models (LMs) trained to predict the next word given input text are the key technology for many applications [&lt;a href="https://blog.google/technology/ai/google-palm-2-ai-large-language-model/"&gt;1&lt;/a&gt;, &lt;a href="https://blog.google/technology/ai/google-gemini-ai/"&gt;2&lt;/a&gt;]. In &lt;a href="https://play.google.com/store/apps/details?id=com.google.android.inputmethod.latin&amp;amp;hl=en_US&amp;amp;gl=US"&gt;Gboard&lt;/a&gt;, LMs are used to improve users’ typing experience by supporting features like &lt;a href="https://arxiv.org/abs/1811.03604"&gt;next word prediction&lt;/a&gt; (NWP), &lt;a href="https://support.google.com/gboard/answer/7068415"&gt;Smart Compose&lt;/a&gt;, &lt;a href="https://support.google.com/gboard/answer/7068415"&gt;smart completion&lt;/a&gt; and &lt;a href="https://support.google.com/gboard/answer/7068415"&gt;suggestion&lt;/a&gt;, &lt;a href="https://support.google.com/gboard/answer/2811346"&gt;slide to type&lt;/a&gt;, and &lt;a href="https://support.google.com/gboard/answer/7068415"&gt;proofread&lt;/a&gt;. Deploying models on users’ devices rather than enterprise servers has advantages like lower latency and better privacy for model usage. While training on-device models directly from user data effectively improves utility for applications such as NWP and &lt;a href="https://blog.research.google/2021/11/predicting-text-selections-with.html"&gt;smart text selection&lt;/a&gt;, protecting the privacy of user data for model training is important. 
&lt;/p&gt;

&lt;a name='more'&gt;&lt;/a&gt;




&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEiWvaPvikHjeVBb9njeoP2z499_LU0a4VfEgI2kOVxYEoApqgZ49-Ej_TpY6pyoy9HKU2jASzSBsKhdXuOhP-ykpsK_makFmWzVF67BPS3PSpRrCIxC0hYHogBVcDM74AXmjD5hh2mP22tPmXQqEkOak9QXXLyJOCsJB94dv0P-W3IINYyah2O-nF1HLTXE/s1996/image45.gif" imageanchor="1" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="1600" data-original-width="1996" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEiWvaPvikHjeVBb9njeoP2z499_LU0a4VfEgI2kOVxYEoApqgZ49-Ej_TpY6pyoy9HKU2jASzSBsKhdXuOhP-ykpsK_makFmWzVF67BPS3PSpRrCIxC0hYHogBVcDM74AXmjD5hh2mP22tPmXQqEkOak9QXXLyJOCsJB94dv0P-W3IINYyah2O-nF1HLTXE/s16000/image45.gif" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;Gboard features powered by on-device language models.&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;

&lt;p&gt;
In this blog we discuss how years of research advances now power the private training of Gboard LMs, since the proof-of-concept development of &lt;a href="https://blog.research.google/2017/04/federated-learning-collaborative.html"&gt;federated learning&lt;/a&gt; (FL) in 2017 and formal &lt;a href="https://blog.research.google/2022/02/federated-learning-with-formal.html"&gt;differential privacy&lt;/a&gt; (DP) guarantees in 2022. &lt;a href="https://blog.research.google/2017/04/federated-learning-collaborative.html"&gt;FL&lt;/a&gt; enables mobile phones to collaboratively learn a model while keeping all the training data on device, and &lt;a href="https://en.wikipedia.org/wiki/Differential_privacy"&gt;DP&lt;/a&gt; provides a quantifiable measure of data anonymization. Formally, DP is often characterized by (&lt;em&gt;ε&lt;/em&gt;, &lt;em&gt;δ&lt;/em&gt;) with smaller values representing stronger guarantees. Machine learning (ML) models are considered to have &lt;a href="https://blog.research.google/2023/05/making-ml-models-differentially-private.html"&gt;reasonable DP guarantees for ε=10 and strong DP guarantees for ε=1&lt;/a&gt; when &lt;em&gt;δ&lt;/em&gt; is small. 
&lt;/p&gt;
&lt;p&gt;
As of today, all NWP neural network LMs in Gboard are trained with FL with formal DP guarantees, and all future launches of Gboard LMs trained on user data require DP. These 30+ Gboard on-device LMs are launched in 7+ languages and 15+ countries, and satisfy (&lt;em&gt;ɛ&lt;/em&gt;, &lt;em&gt;δ&lt;/em&gt;)-DP guarantees of small &lt;em&gt;δ&lt;/em&gt; of 10&lt;sup&gt;-10&lt;/sup&gt; and ɛ between 0.994 and 13.69. To the best of our knowledge, this is the largest known deployment of user-level DP in production at Google or anywhere, and the first time a strong DP guarantee of &lt;em&gt;ɛ&lt;/em&gt; &amp;lt; 1 is announced for models trained directly on user data. 
&lt;/p&gt;


&lt;div style="line-height: 40%;"&gt;
    &lt;br /&gt;
&lt;/div&gt;
&lt;h2&gt;Privacy principles and practices in Gboard&lt;/h2&gt;


&lt;p&gt;
In “&lt;a href="https://arxiv.org/abs/2306.14793"&gt;Private Federated Learning in Gboard&lt;/a&gt;”, we discussed how different &lt;a href="https://queue.acm.org/detail.cfm?id=3501293"&gt;privacy principles&lt;/a&gt; are currently reflected in production models, including:
&lt;/p&gt;
&lt;ul&gt;

&lt;li&gt;&lt;em&gt;Transparency and user control&lt;/em&gt;: We provide disclosure of what data is used, what purpose it is used for, how it is processed in various channels, and how Gboard users can easily &lt;a href="https://support.google.com/gboard/answer/12373137"&gt;configure&lt;/a&gt; the data usage in learning models. 

&lt;/li&gt;&lt;li&gt;&lt;em&gt;Data minimization&lt;/em&gt;: FL immediately aggregates only focused updates that improve a specific model. &lt;a href="https://eprint.iacr.org/2017/281.pdf"&gt;Secure aggregation&lt;/a&gt; (SecAgg) is an encryption method to further guarantee that only aggregated results of the ephemeral updates can be accessed.   

&lt;/li&gt;&lt;li&gt;&lt;em&gt;Data anonymization&lt;/em&gt;: DP is applied by the server to prevent models from memorizing the unique information in individual users’ training data. 

&lt;/li&gt;&lt;li&gt;&lt;em&gt;Auditability and verifiability&lt;/em&gt;: We have made public the key algorithmic approaches and privacy accounting in open-sourced code (&lt;a href="https://github.com/tensorflow/federated/blob/main/tensorflow_federated/python/aggregators/differential_privacy.py"&gt;TFF aggregator&lt;/a&gt;, &lt;a href="https://github.com/tensorflow/privacy/blob/master/tensorflow_privacy/privacy/dp_query/tree_aggregation_query.py"&gt;TFP DPQuery&lt;/a&gt;, &lt;a href="https://github.com/google-research/federated/blob/master/dp_ftrl/blogpost_supplemental_privacy_accounting.ipynb"&gt;DP accounting&lt;/a&gt;, and &lt;a href="https://github.com/google/federated-compute"&gt;FL system&lt;/a&gt;). 
&lt;/li&gt;
&lt;/ul&gt;


&lt;div style="line-height: 40%;"&gt;
    &lt;br /&gt;
&lt;/div&gt;
&lt;h3&gt;A brief history&lt;/h3&gt;


&lt;p&gt;
In recent years, FL has become the default method for training &lt;a href="https://arxiv.org/abs/1811.03604"&gt;Gboard on-device LMs&lt;/a&gt; from user data. In 2020, a DP mechanism that &lt;a href="https://arxiv.org/abs/1710.06963"&gt;clips and adds noise&lt;/a&gt; to model updates was used to &lt;a href="https://arxiv.org/abs/2009.10031"&gt;prevent memorization&lt;/a&gt; for training the Spanish LM in Spain, which satisfies finite DP guarantees (&lt;a href="https://blog.research.google/2023/05/making-ml-models-differentially-private.html"&gt;Tier 3&lt;/a&gt; as described in the “&lt;a href="https://arxiv.org/abs/2303.00654"&gt;How to DP-fy ML&lt;/a&gt;” guide). In 2022, with the help of the &lt;a href="https://arxiv.org/abs/2103.00039"&gt;DP-Follow-The-Regularized-Leader (DP-FTRL) algorithm&lt;/a&gt;, the Spanish LM became the first production neural network trained directly on user data announced with &lt;a href="https://blog.research.google/2022/02/federated-learning-with-formal.html"&gt;a formal DP guarantee of (ε=8.9, δ=10&lt;sup&gt;-10&lt;/sup&gt;)-DP&lt;/a&gt; (equivalent to the reported &lt;em&gt;&lt;a href="https://blog.research.google/2022/02/federated-learning-with-formal.html"&gt;ρ=0.81&lt;/a&gt;&lt;/em&gt; &lt;a href="https://arxiv.org/abs/1605.02065"&gt;zero-Concentrated-Differential-Privacy&lt;/a&gt;), and therefore satisfies &lt;a href="https://blog.research.google/2023/05/making-ml-models-differentially-private.html"&gt;reasonable privacy guarantees&lt;/a&gt; (&lt;a href="https://blog.research.google/2023/05/making-ml-models-differentially-private.html"&gt;Tier 2&lt;/a&gt;). 
&lt;/p&gt;


&lt;div style="line-height: 40%;"&gt;
    &lt;br /&gt;
&lt;/div&gt;
&lt;h2&gt;Differential privacy by default in federated learning &lt;/h2&gt;


&lt;p&gt;
In “&lt;a href="https://arxiv.org/abs/2305.18465"&gt;Federated Learning of Gboard Language Models with Differential Privacy&lt;/a&gt;”, we announced that all the NWP neural network LMs in Gboard have DP guarantees, and all future launches of Gboard LMs trained on user data require DP guarantees. DP is enabled in FL by applying the following practices:
&lt;/p&gt;
&lt;ul&gt;

&lt;li&gt;Pre-train the model with the &lt;a href="https://arxiv.org/abs/2010.11934"&gt;multilingual&lt;/a&gt; &lt;a href="https://arxiv.org/abs/1910.10683"&gt;C4&lt;/a&gt; dataset.  

&lt;/li&gt;&lt;li&gt;Via simulation experiments on public datasets, find a large DP-noise-to-signal ratio that allows for high utility. Increasing the number of clients contributing to one round of model update improves privacy while keeping the noise ratio fixed for good utility, up to the point the DP target is met, or the maximum allowed by the system and the size of the population.

&lt;/li&gt;&lt;li&gt;Configure the parameter to restrict the frequency each client can contribute (e.g., once every few days) based on computation budget and estimated population in &lt;a href="https://arxiv.org/abs/1902.01046"&gt;the FL system&lt;/a&gt;. 

&lt;/li&gt;&lt;li&gt;Run &lt;a href="https://arxiv.org/abs/2103.00039"&gt;DP-FTRL&lt;/a&gt; training with limits on the magnitude of per-device updates chosen either via &lt;a href="https://github.com/tensorflow/federated/commit/ee9d08368828ea730662e5e2b3a90e103368b6b6"&gt;adaptive clipping&lt;/a&gt;, or fixed based on experience. 
&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;
SecAgg can be additionally applied by adopting the &lt;a href="https://blog.research.google/2023/03/distributed-differential-privacy-for.html"&gt;advances in improving computation and communication for scales and sensitivity&lt;/a&gt;.
&lt;/p&gt;



&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEht2ZweKyxBRqShB6i41lTpZmfS2gEi2rbNHFGgT-36di1HMxwV6caxFJ2lUXpznxuXYHEb928yfHwueojKlB-gxfKfT4aEv-_2mUlO5zlaWNPceMDGdnOVWp4M8T5qCzMPTuinPOtRy1WmXMtsaSpNpMLvokQKlOnWYFMJF0tXbhmc-dkpI-o7T4FBn8-N/s1600/image3.gif" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="1000" data-original-width="1600" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEht2ZweKyxBRqShB6i41lTpZmfS2gEi2rbNHFGgT-36di1HMxwV6caxFJ2lUXpznxuXYHEb928yfHwueojKlB-gxfKfT4aEv-_2mUlO5zlaWNPceMDGdnOVWp4M8T5qCzMPTuinPOtRy1WmXMtsaSpNpMLvokQKlOnWYFMJF0tXbhmc-dkpI-o7T4FBn8-N/s16000/image3.gif" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;Federated learning with differential privacy and (SecAgg).&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;



&lt;div style="line-height: 40%;"&gt;
    &lt;br /&gt;
&lt;/div&gt;
&lt;h3&gt;Reporting DP guarantees&lt;/h3&gt;


&lt;p&gt;
The DP guarantees of launched Gboard NWP LMs are visualized in the bar plot below. The &lt;em&gt;x&lt;/em&gt;-axis shows LMs labeled by language-locale and trained on corresponding populations; the &lt;em&gt;y&lt;/em&gt;-axis shows the &lt;em&gt;ε&lt;/em&gt; value when &lt;em&gt;δ&lt;/em&gt; is fixed to a small value of 10&lt;sup&gt;-10&lt;/sup&gt; for &lt;a href="https://www.iacr.org/archive/eurocrypt2006/40040493/40040493.pdf"&gt;(ε, δ)-DP&lt;/a&gt; (lower is better). The utility of these models is either significantly better than that of previous non-neural models in production, or comparable with previous LMs without DP, as measured by user-interaction metrics during A/B testing. For example, by applying the best practices, the DP guarantee of the Spanish model in Spain is improved from &lt;em&gt;&lt;a href="https://blog.research.google/2022/02/federated-learning-with-formal.html"&gt;ε=8.9&lt;/a&gt;&lt;/em&gt; to &lt;em&gt;ε&lt;/em&gt;=5.37. SecAgg is additionally used for training the Spanish model in Spain and the English model in the US. More details of the DP guarantees are reported in &lt;a href="https://arxiv.org/abs/2305.18465"&gt;the appendix&lt;/a&gt; following the &lt;a href="https://blog.research.google/2023/05/making-ml-models-differentially-private.html"&gt;guidelines outlined&lt;/a&gt; in “&lt;a href="https://arxiv.org/abs/2303.00654"&gt;How to DP-fy ML&lt;/a&gt;”. 
&lt;/p&gt;



&lt;div style="line-height: 40%;"&gt;
    &lt;br /&gt;
&lt;/div&gt;
&lt;h2&gt;Towards stronger DP guarantees&lt;/h2&gt;


&lt;p&gt;
The &lt;em&gt;ε&lt;/em&gt;~10 DP guarantees of many launched LMs are already considered &lt;a href="https://blog.research.google/2023/05/making-ml-models-differentially-private.html"&gt;reasonable&lt;/a&gt; for ML models in practice, and the journey of DP FL in Gboard continues, improving users’ typing experience while protecting data privacy. We are excited to announce that, for the first time, production LMs of Portuguese in Brazil and Spanish in Latin America are trained and launched with a DP guarantee of &lt;em&gt;ε&lt;/em&gt; ≤ 1, which satisfies &lt;a href="https://blog.research.google/2023/05/making-ml-models-differentially-private.html"&gt;Tier 1 strong privacy guarantees&lt;/a&gt;. Specifically, the (&lt;em&gt;ε&lt;/em&gt;=0.994, &lt;em&gt;δ&lt;/em&gt;=10&lt;sup&gt;-10&lt;/sup&gt;)-DP guarantee is achieved by running the advanced &lt;a href="https://arxiv.org/abs/2306.08153"&gt;Matrix Factorization DP-FTRL&lt;/a&gt; (MF-DP-FTRL) algorithm, with 12,000+ devices participating in every training round of server model update (more than the &lt;a href="https://arxiv.org/abs/2305.18465"&gt;common setting of 6500+ devices&lt;/a&gt;), and a carefully configured policy that restricts each client to participate at most twice in the total of 2000 training rounds over 14 days in the large Portuguese user population of Brazil. Using a similar setting, the es-US Spanish LM was trained on a large population combining multiple countries in Latin America to achieve (&lt;em&gt;ε&lt;/em&gt;=0.994, &lt;em&gt;δ&lt;/em&gt;=10&lt;sup&gt;-10&lt;/sup&gt;)-DP. The &lt;em&gt;ε&lt;/em&gt; ≤ 1 es-US model significantly improved utility in many countries and was launched in Colombia, Ecuador, Guatemala, Mexico, and Venezuela. 
For the smaller population in Spain, the DP guarantee of es-ES LM is improved from &lt;em&gt;&lt;a href="https://arxiv.org/abs/2305.18465"&gt;ε=5.37&lt;/a&gt;&lt;/em&gt; to &lt;em&gt;ε&lt;/em&gt;=3.42 by only replacing &lt;a href="https://arxiv.org/abs/2103.00039"&gt;DP-FTRL&lt;/a&gt; with &lt;a href="https://arxiv.org/abs/2306.08153"&gt;MF-DP-FTRL&lt;/a&gt; without increasing the number of devices participating every round. More technical details are disclosed in the &lt;a href="https://colab.sandbox.google.com/github/google-research/federated/blob/master/mf_dpftrl_matrices/privacy_accounting.ipynb"&gt;colab&lt;/a&gt; for privacy accounting. 
&lt;/p&gt;




&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgp1yNOAbd8IRoisQDX-OHq-a8PUDH2V1OF7btRsUXI86-tuEXwrR8otAGEqPN8J2HGcpH9aB25s04Nybm_Vn6bpRmfD_AHnHYkGJtld7ockal6mhdRXcsA-M6rf3vM7kzQ5hXfdPbw9hk7bsQU8EV4ul5QAn3Hw4b1yXIKjnokfhrkEF0hNXGt9DbLU3yk/s1999/image1.png" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="709" data-original-width="1999" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgp1yNOAbd8IRoisQDX-OHq-a8PUDH2V1OF7btRsUXI86-tuEXwrR8otAGEqPN8J2HGcpH9aB25s04Nybm_Vn6bpRmfD_AHnHYkGJtld7ockal6mhdRXcsA-M6rf3vM7kzQ5hXfdPbw9hk7bsQU8EV4ul5QAn3Hw4b1yXIKjnokfhrkEF0hNXGt9DbLU3yk/s16000/image1.png" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;DP guarantees for Gboard NWP LMs (the purple bar represents the first es-ES launch of ε=8.9; cyan bars represent privacy improvements for models trained with &lt;a href="https://arxiv.org/abs/2306.08153"&gt;MF-DP-FTRL&lt;/a&gt;; &lt;a href="https://blog.research.google/2023/05/making-ml-models-differentially-private.html"&gt;tiers &lt;/a&gt;are from “&lt;a href="https://arxiv.org/abs/2303.00654"&gt;How to DP-fy ML&lt;/a&gt;“ guide; en-US* and es-ES* are additionally trained with SecAgg).&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;



&lt;div style="line-height: 40%;"&gt;
    &lt;br /&gt;
&lt;/div&gt;
&lt;h2&gt;Discussion and next steps&lt;/h2&gt;


&lt;p&gt;
Our experience suggests that DP can be achieved in practice through system-algorithm co-design on client participation, and that both privacy and utility can be strong when populations are large &lt;em&gt;and&lt;/em&gt; a large number of devices' contributions are aggregated. Privacy-utility-computation trade-offs can be improved by &lt;a href="https://arxiv.org/abs/2305.18465"&gt;using public data&lt;/a&gt;, the &lt;a href="https://arxiv.org/abs/2306.08153"&gt;new MF-DP-FTRL algorithm&lt;/a&gt;, and &lt;a href="https://github.com/google/differential-privacy"&gt;tightening accounting&lt;/a&gt;. With these techniques, a strong DP guarantee of &lt;em&gt;ε&lt;/em&gt; ≤ 1 is possible but still challenging. Active research on empirical privacy auditing [&lt;a href="https://arxiv.org/abs/2302.03098"&gt;1&lt;/a&gt;, &lt;a href="https://arxiv.org/abs/2305.08846"&gt;2&lt;/a&gt;] suggests that DP models are potentially more private than the worst-case DP guarantees imply. While we keep pushing the frontier of algorithms, which dimension of the privacy-utility-computation trade-off should be prioritized?
&lt;/p&gt;
&lt;p&gt;
We are actively working on all privacy aspects of ML, including extending DP-FTRL to &lt;a href="https://blog.research.google/2023/03/distributed-differential-privacy-for.html"&gt;distributed DP&lt;/a&gt; and improving &lt;a href="https://arxiv.org/abs/2306.14793"&gt;auditability and verifiability&lt;/a&gt;. &lt;a href="https://en.wikipedia.org/wiki/Trusted_execution_environment"&gt;Trusted Execution Environments&lt;/a&gt; (TEEs) open the opportunity to substantially increase model size with verifiable privacy. The recent &lt;a href="https://blog.google/technology/ai/google-gemini-ai/"&gt;breakthrough in large LMs&lt;/a&gt; (LLMs) motivates us to &lt;a href="https://arxiv.org/abs/2305.12132"&gt;rethink&lt;/a&gt; the usage of &lt;a href="https://arxiv.org/abs/2212.06470"&gt;public&lt;/a&gt; information in private training, and to explore more interactions between LLMs, on-device LMs, and Gboard production in the future.  
&lt;/p&gt;



&lt;div style="line-height: 40%;"&gt;
    &lt;br /&gt;
&lt;/div&gt;
&lt;h2&gt;Acknowledgments&lt;/h2&gt;


&lt;p&gt;
&lt;em&gt;The authors would like to thank Peter Kairouz, Brendan McMahan, and Daniel Ramage for their early feedback on the blog post itself, Shaofeng Li and Tom Small for helping with the animated figures, and the teams at Google that helped with algorithm design, infrastructure implementation, and production maintenance. The collaborators below directly contribute to the presented results:&lt;/em&gt;
&lt;/p&gt;
&lt;p&gt;
&lt;em&gt;Research and algorithm development: Galen Andrew, Stanislav Chiknavaryan, Christopher A. Choquette-Choo, Arun Ganesh, Peter Kairouz, Ryan McKenna, H. Brendan McMahan, Jesse Rosenstock, Timon Van Overveldt, Keith Rush, Shuang Song, Thomas Steinke, Abhradeep Guha Thakurta, Om Thakkar, and Yuanbo Zhang.&lt;/em&gt;
&lt;/p&gt;
&lt;p&gt;
&lt;em&gt;Infrastructure, production and leadership support: Mingqing Chen, Stefan Dierauf, Billy Dou, Hubert Eichner, Zachary Garrett, Jeremy Gillula, Jianpeng Hou, Hui Li, Xu Liu, Wenzhi Mao, Brett McLarnon, Mengchen Pei, Daniel Ramage, Swaroop Ramaswamy, Haicheng Sun, Andreas Terzis, Yun Wang, Shanshan Wu, Yu Xiao, and Shumin Zhai.&lt;/em&gt;
&lt;/p&gt;</content><link href="http://blog.research.google/feeds/4343235509909091741/comments/default" rel="replies" title="Post Comments" type="application/atom+xml"/><link href="http://blog.research.google/2024/02/advances-in-private-training-for.html#comment-form" rel="replies" title="0 Comments" type="text/html"/><link href="http://www.blogger.com/feeds/8474926331452026626/posts/default/4343235509909091741" rel="edit" type="application/atom+xml"/><link href="http://www.blogger.com/feeds/8474926331452026626/posts/default/4343235509909091741" rel="self" type="application/atom+xml"/><link href="http://blog.research.google/2024/02/advances-in-private-training-for.html" rel="alternate" title="Advances in private training for production on-device language models" type="text/html"/><author><name>Google AI</name><uri>http://www.blogger.com/profile/12098626514775266161</uri><email>noreply@blogger.com</email><gd:image height="16" rel="http://schemas.google.com/g/2005#thumbnail" src="https://img1.blogblog.com/img/b16-rounded.gif" width="16"/></author><media:thumbnail xmlns:media="http://search.yahoo.com/mrss/" height="72" url="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEifnCZ_XGUoUG0hESM0dF5B8Rsoqo4YrT_-uv0hlDM1iTADhtEEyEvBM4hOWT0rxgpVtZKyuFoj2xeXmkeXwGe-XTmvBuwBDJOCqgN8Ba7Wcjh_s1seWUaCRl1xNpNe_6MqxcFFZoAvhfCge5vq9UATjXG_BnTiGdQ6YLLo7AK7ABS3KLFMKmjAtA1gkcBk/s72-c/GBoard%20PrivacyHero.gif" width="72"/><thr:total>0</thr:total></entry><entry><id>tag:blogger.com,1999:blog-8474926331452026626.post-5605933033299261025</id><published>2024-02-14T10:32:00.000-08:00</published><updated>2024-02-14T10:32:25.557-08:00</updated><category scheme="http://www.blogger.com/atom/ns#" term="Deep Learning"/><category scheme="http://www.blogger.com/atom/ns#" term="Supervised Learning"/><title type="text">Learning the importance of training data under concept drift</title><content type="html">&lt;span class="byline-author"&gt;Posted by Nishant Jain, Pre-doctoral Researcher, and 
Pradeep Shenoy, Research Scientist, Google Research&lt;/span&gt;

&lt;img src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgUeskw4YD6cFTpLaRnv7OwMsljyeipfAb1riYxIuBsiWd6TBmUXMJ4QoI9tlvUzWX9NzBbEjz3-P2Zl2kuXe5BrVclmqQFrLButoya5phiEELq1azrhsIaGaCz-ov_jXaMsFrGRDE0EjotyRQPOX3xV5MAkVJfKp9xecX4t2CoLBiZ8r2RpZ25Y5KRitFG/s1600/temporalreweightinghero.png" style="display: none;" /&gt;

&lt;p&gt;
The constantly changing nature of the world around us poses a significant challenge for the development of AI models. Often, models are trained on longitudinal data with the hope that the training data used will accurately represent inputs the model may receive in the future. More generally, the default assumption that all training data are equally relevant often breaks in practice. For example, the figure below shows images from the &lt;a href="https://arxiv.org/abs/2201.06289"&gt;CLEAR&lt;/a&gt; nonstationary learning benchmark, and it illustrates how visual features of objects evolve significantly over a 10 year span (a phenomenon we refer to as &lt;em&gt;slow concept drift&lt;/em&gt;), posing a challenge for object categorization models. 
&lt;/p&gt;

&lt;a name='more'&gt;&lt;/a&gt;



&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjBAkCetRQiAPA4cmiXvtwa2SJ0pMwvRYDcuL7rQEDHxEgi9lAyU69bBeeEw-_k182BITn4w2WtdE5QfUwaF-Ny-Dkai-pLeHV23mlgAwrX_0le28l5hba9q9QUO3LeYl2jgkPGkKcLW7dtnGFMiY7PrZbpigSggAiOSrRB8X9eQZGHLE8H7TZoxYy4AD2Q/s1999/image4.png" imageanchor="1" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="662" data-original-width="1999" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjBAkCetRQiAPA4cmiXvtwa2SJ0pMwvRYDcuL7rQEDHxEgi9lAyU69bBeeEw-_k182BITn4w2WtdE5QfUwaF-Ny-Dkai-pLeHV23mlgAwrX_0le28l5hba9q9QUO3LeYl2jgkPGkKcLW7dtnGFMiY7PrZbpigSggAiOSrRB8X9eQZGHLE8H7TZoxYy4AD2Q/s16000/image4.png" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;Sample images from the CLEAR benchmark. (Adapted from Lin et al&lt;a href="https://arxiv.org/abs/2201.06289"&gt;.&lt;/a&gt;)&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;

&lt;br&gt;


&lt;p&gt;
Alternative approaches, such as &lt;a href="https://en.wikipedia.org/wiki/Online_machine_learning"&gt;online&lt;/a&gt; and &lt;a href="https://wiki.continualai.org/the-continualai-wiki/introduction-to-continual-learning"&gt;continual learning&lt;/a&gt;, repeatedly update a model with small amounts of recent data in order to keep it current. This implicitly prioritizes recent data, as the learnings from past data are gradually erased by subsequent updates. However, in the real world, different kinds of information lose relevance at different rates, so these approaches have two key issues: 1) By design they focus &lt;em&gt;exclusively&lt;/em&gt; on the most recent data and lose any signal from older data that is erased. 2) Contributions from data instances decay &lt;em&gt;uniformly over time&lt;/em&gt; irrespective of the contents of the data.
&lt;/p&gt;
&lt;p&gt;
In our recent work, “&lt;a href="https://arxiv.org/abs/2212.05908"&gt;Instance-Conditional Timescales of Decay for Non-Stationary Learning&lt;/a&gt;”, we propose to assign each instance an importance score during training in order to maximize model performance on future data. To accomplish this, we employ an auxiliary model that produces these scores using the training instance as well as its age. This model is jointly learned with the primary model. We address both the above challenges and achieve significant gains over other robust learning methods on a range of benchmark datasets for nonstationary learning. For instance, on a &lt;a href="https://arxiv.org/abs/2108.09020"&gt;recent large-scale benchmark&lt;/a&gt; for nonstationary learning (~39M photos over a 10 year period), we show up to 15% relative accuracy gains through learned reweighting of training data.
&lt;/p&gt;



&lt;div style="line-height:40%;"&gt;
    &lt;br&gt;
&lt;/div&gt;
&lt;h2&gt;The challenge of concept drift for supervised learning&lt;/h2&gt;


&lt;p&gt;
To gain quantitative insight into slow concept drift, we built classifiers on a &lt;a href="https://arxiv.org/abs/2108.09020"&gt;recent photo categorization task&lt;/a&gt;, comprising roughly 39M photographs sourced from social media websites over a 10-year period. We compared offline training, which iterated over all the training data multiple times in random order, and continual training, which iterated multiple times over each month of data in sequential (temporal) order. We measured model accuracy both during the training period and during a subsequent period where both models were frozen, i.e., not updated further on new data (shown below). At the end of the training period (left panel, x-axis = 0), both approaches have seen the same amount of data, but show a large performance gap. This is due to &lt;a href="https://www.sciencedirect.com/science/article/abs/pii/S0079742108605368"&gt;catastrophic forgetting&lt;/a&gt;, a problem in continual learning where a model’s knowledge of data from early on in the training sequence is diminished in an uncontrolled manner. On the other hand, forgetting has its advantages — over the test period (shown on the right), the continually trained model degrades much less rapidly than the offline model because it is less dependent on older data. The decay of both models’ accuracy in the test period is confirmation that the data is indeed evolving over time, and both models become increasingly less relevant.
&lt;/p&gt;




&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEizQmgaL3NNsCLWbeTndyOxPikcGKqQIrpDisMVTy-7eAIxamEv3Klpncd5B4SB19yNnPmpySlfAz_hPN8x4zV7o0LPmcLKEnyVJBctKuLF8plITBmDz3BTR2aPHqlKarPPHZHpp0EY0M3HA9l5oV_IOaQS5UzS-uMaNq3Fi1D1qHUYJ6XC-4t0_xS91fnw/s1554/image2.png" imageanchor="1" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="616" data-original-width="1554" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEizQmgaL3NNsCLWbeTndyOxPikcGKqQIrpDisMVTy-7eAIxamEv3Klpncd5B4SB19yNnPmpySlfAz_hPN8x4zV7o0LPmcLKEnyVJBctKuLF8plITBmDz3BTR2aPHqlKarPPHZHpp0EY0M3HA9l5oV_IOaQS5UzS-uMaNq3Fi1D1qHUYJ6XC-4t0_xS91fnw/s16000/image2.png" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;Comparing offline and continually trained models on the photo classification task.&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;


&lt;br&gt;
&lt;div style="line-height:40%;"&gt;
    &lt;br&gt;
&lt;/div&gt;
&lt;h2&gt;Time-sensitive reweighting of training data&lt;/h2&gt;


&lt;p&gt;
We design a method combining the benefits of offline learning (the flexibility of effectively reusing all available data) and continual learning (the ability to downplay older data) to address slow concept drift. We build upon offline learning, then add careful control over the influence of past data and an optimization objective, both designed to reduce model decay in the future. 
&lt;/p&gt;
&lt;p&gt;
Suppose we wish to train a model, &lt;em&gt;M&lt;/em&gt;, given some training data collected over time. We propose to also train a helper model that assigns a weight to each point based on its contents and age. This weight scales the contribution from that data point in the training objective for &lt;em&gt;M&lt;/em&gt;. The objective of the weights is to improve the performance of &lt;em&gt;M&lt;/em&gt; on future data. 
&lt;/p&gt;
&lt;p&gt;
In &lt;a href="https://arxiv.org/abs/2212.05908"&gt;our work&lt;/a&gt;, we describe how the helper model can be &lt;em&gt;meta-learned, &lt;/em&gt;i.e., learned alongside &lt;em&gt;M&lt;/em&gt; in a manner that helps the learning of the model &lt;em&gt;M&lt;/em&gt; itself. A key design choice of the helper model is that we separated out instance- and age-related contributions in a factored manner. Specifically, we set the weight by combining contributions from multiple different fixed timescales of decay, and learn an approximate “assignment” of a given instance to its most suited timescales. We find in our experiments that this form of the helper model outperforms many other alternatives we considered, ranging from unconstrained joint functions to a single timescale of decay (exponential or linear), due to its combination of simplicity and expressivity. Full details may be found in the &lt;a href="https://arxiv.org/abs/2212.05908"&gt;paper&lt;/a&gt;. 
&lt;/p&gt;



&lt;div style="line-height:40%;"&gt;
    &lt;br&gt;
&lt;/div&gt;
&lt;h2&gt;Instance weight scoring&lt;/h2&gt;


&lt;p&gt;
The top figure below shows that our learned helper model indeed up-weights more modern-looking objects in the &lt;a href="https://arxiv.org/abs/2201.06289"&gt;CLEAR object recognition challenge&lt;/a&gt;; older-looking objects are correspondingly down-weighted. On closer examination (bottom figure below, gradient-based &lt;a href="https://arxiv.org/abs/1610.02391"&gt;feature importance&lt;/a&gt; assessment), we see that the helper model focuses on the primary object within the image, as opposed to, e.g., background features that may spuriously be correlated with instance age.
&lt;/p&gt;




&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEggQImnpFiW7s3jeT9qoxQOM1kT8vIaHihnlAPusLRx8lJCaxyB7Lzhewn7J6qTiz9-qkWBJzzxLj-uHXhlB94WBMUVRsAgqZVBMBAnDaHGeCe6evZOo6hYgR5oXImP5vO9ZUNcF1q3Bpvau94hM9D71xwOGRqm9c8lJ6ixrB69w_JjneqW5JGcg_u6ZW2J/s1999/image1.png" imageanchor="1" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="499" data-original-width="1999" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEggQImnpFiW7s3jeT9qoxQOM1kT8vIaHihnlAPusLRx8lJCaxyB7Lzhewn7J6qTiz9-qkWBJzzxLj-uHXhlB94WBMUVRsAgqZVBMBAnDaHGeCe6evZOo6hYgR5oXImP5vO9ZUNcF1q3Bpvau94hM9D71xwOGRqm9c8lJ6ixrB69w_JjneqW5JGcg_u6ZW2J/s16000/image1.png" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;Sample images from the &lt;a href="https://arxiv.org/abs/2201.06289"&gt;CLEAR&lt;/a&gt; benchmark (camera &amp;amp; computer categories) assigned the highest and lowest weights respectively by our helper model.&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;

&lt;br&gt;




&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhiKCafyxNrHJUkwV3KjoFMJk_v9WSPlzfMyYa-TZCODZdBNCnUOLOZogf9njyGQp_TWzCZ-a6-P5smLhSyeHVFd_jaSBbmS9soN5A5AF6oTq_OWvk-xOWgKaDCIFYz8mhe-GoVEZ56QSsIpKxDduNmCA0ORnf_kgW8ph0uZci8UBCQDBHs0j4Nq5hb5J7e/s1999/image5.png" imageanchor="1" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="339" data-original-width="1999" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhiKCafyxNrHJUkwV3KjoFMJk_v9WSPlzfMyYa-TZCODZdBNCnUOLOZogf9njyGQp_TWzCZ-a6-P5smLhSyeHVFd_jaSBbmS9soN5A5AF6oTq_OWvk-xOWgKaDCIFYz8mhe-GoVEZ56QSsIpKxDduNmCA0ORnf_kgW8ph0uZci8UBCQDBHs0j4Nq5hb5J7e/s16000/image5.png" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;Feature importance analysis of our helper model on sample images from the &lt;a href="https://arxiv.org/abs/2201.06289"&gt;CLEAR&lt;/a&gt; benchmark.&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;


&lt;br&gt;

&lt;div style="line-height:40%;"&gt;
    &lt;br&gt;
&lt;/div&gt;
&lt;h2&gt;Results&lt;/h2&gt;

&lt;div style="line-height:40%;"&gt;
    &lt;br&gt;
&lt;/div&gt;
&lt;h3&gt;Gains on large-scale data &lt;/h3&gt;


&lt;p&gt;
We first study the large-scale &lt;a href="https://arxiv.org/abs/2108.09020"&gt;photo categorization task&lt;/a&gt; (PCAT) on the &lt;a href="https://arxiv.org/abs/1503.01817"&gt;YFCC100M dataset&lt;/a&gt; discussed earlier, using the first five years of data for training and the next five years as test data. Our method (shown in red below) improves substantially over the no-reweighting baseline (black) as well as many other robust learning techniques. Interestingly, our method deliberately trades off accuracy on the distant past (training data unlikely to recur in the future) in exchange for marked improvements in the test period. Also, as desired, our method degrades less than other baselines in the test period. 
&lt;/p&gt;




&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgLKQZo3e80Ttgw64eHAndZZc6BMXKBNLXAPTQZDP1tsFEQZpGckd6fzqG0aC1x_b5HQmiYlp6AzgbQ3gYRGVcHEZvhnPiDVsl1rxKh3vjVtqXJd20xp5og5yowR2SmyvqNdhhaSuNT5IY_rm_SJanFAsM4jt1Pf_TChyphenhyphenK8y0mNi2Jji1oDWcSiH_7vaC7b/s800/image3.png" imageanchor="1" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="600" data-original-width="800" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgLKQZo3e80Ttgw64eHAndZZc6BMXKBNLXAPTQZDP1tsFEQZpGckd6fzqG0aC1x_b5HQmiYlp6AzgbQ3gYRGVcHEZvhnPiDVsl1rxKh3vjVtqXJd20xp5og5yowR2SmyvqNdhhaSuNT5IY_rm_SJanFAsM4jt1Pf_TChyphenhyphenK8y0mNi2Jji1oDWcSiH_7vaC7b/s16000/image3.png" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;Comparison of our method and relevant baselines on the PCAT dataset.&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;

&lt;br&gt;


&lt;div style="line-height:40%;"&gt;
    &lt;br&gt;
&lt;/div&gt;
&lt;h3&gt;Broad applicability&lt;/h3&gt;


&lt;p&gt;
We validated our findings on a wide range of nonstationary learning challenge datasets sourced from the academic literature (see &lt;a href="https://arxiv.org/abs/2108.09020"&gt;1&lt;/a&gt;, &lt;a href="https://arxiv.org/abs/2201.06289"&gt;2&lt;/a&gt;, &lt;a href="https://arxiv.org/abs/2211.14238"&gt;3&lt;/a&gt;, &lt;a href="https://proceedings.mlr.press/v206/awasthi23b/awasthi23b.pdf"&gt;4&lt;/a&gt; for details) that span data sources and modalities (photos, satellite images, social media text, medical records, sensor readings, tabular data) and sizes (ranging from 10k to 39M instances). We report significant gains in the test period when compared to the nearest published benchmark method for each dataset (shown below). Note that the previous best-known method may be different for each dataset. These results showcase the broad applicability of our approach.
&lt;/p&gt;





&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhw95hIflfZ4eiNddWi0-YXONJYbMLT2yHp_Ekzm8v5e1WHpxeT5v7k21EYihoAqrplmlrtM76iiHjuBWtMQDbtj7TvtwIU0eZb44_QSeEe5U4k_z70y_9SsS3If8Y5xkMXKQYI5VzaTafWC7nVv5MgvNw_yL8HA6N7-gUPGGcJI2qtgKTcnqn2oN1ruBt-/s765/image7.png" imageanchor="1" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="552" data-original-width="765" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhw95hIflfZ4eiNddWi0-YXONJYbMLT2yHp_Ekzm8v5e1WHpxeT5v7k21EYihoAqrplmlrtM76iiHjuBWtMQDbtj7TvtwIU0eZb44_QSeEe5U4k_z70y_9SsS3If8Y5xkMXKQYI5VzaTafWC7nVv5MgvNw_yL8HA6N7-gUPGGcJI2qtgKTcnqn2oN1ruBt-/s16000/image7.png" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;Performance gain of our method on a variety of tasks studying natural concept drift. Our reported gains are over the previous best-known method for each dataset.&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;

&lt;br&gt;

&lt;div style="line-height:40%;"&gt;
    &lt;br&gt;
&lt;/div&gt;
&lt;h3&gt;Extensions to continual learning&lt;/h3&gt;


&lt;p&gt;
Finally, we consider an interesting extension of our work. The work above described how offline learning can be extended to handle concept drift using ideas inspired by continual learning. However, sometimes offline learning is infeasible — for example, if the amount of training data available is too large to maintain or process. We adapted our approach to continual learning in a straightforward manner by applying temporal reweighting &lt;em&gt;within the context of&lt;/em&gt; each bucket of data being used to sequentially update the model. This proposal still retains some limitations of continual learning, e.g., model updates are performed only on the most recent data, and all optimization decisions (including our reweighting) are only made over that data. Nevertheless, our approach consistently beats regular continual learning as well as a wide range of other continual learning algorithms on the photo categorization benchmark (see below). Since our approach is complementary to the ideas in many baselines compared here, we anticipate even larger gains when combined with them.
&lt;/p&gt;





&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjtWaqgT_9wt2sckjfrLbQ8LhRK5gL1yTowCf0h2nMnHhBYqfKP7VBwWfbK-5Y5zbYXiKoaF0TKve71FWrHazA4g4SPFD3leb56aZHex95MM_yovx2Y_uO4c5rOA5GzTndUGyBO4HH0gL3jYd8Jk4oPbi4HuSYDuMkKY5kPlqsb0s-re13QKfei2IrMig6S/s800/image6.png" imageanchor="1" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="600" data-original-width="800" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjtWaqgT_9wt2sckjfrLbQ8LhRK5gL1yTowCf0h2nMnHhBYqfKP7VBwWfbK-5Y5zbYXiKoaF0TKve71FWrHazA4g4SPFD3leb56aZHex95MM_yovx2Y_uO4c5rOA5GzTndUGyBO4HH0gL3jYd8Jk4oPbi4HuSYDuMkKY5kPlqsb0s-re13QKfei2IrMig6S/s16000/image6.png" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;Results of our method adapted to continual learning, compared to the latest baselines.&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;
  
 &lt;br&gt; 
&lt;div style="line-height:40%;"&gt;
    &lt;br&gt;
&lt;/div&gt;  
&lt;h2&gt;Conclusion&lt;/h2&gt;


&lt;p&gt;
We addressed the challenge of data drift in learning by combining the strengths of previous approaches — offline learning with its effective reuse of data, and continual learning with its emphasis on more recent data. We hope that our work helps improve model robustness to concept drift in practice, and generates increased interest and new ideas in addressing the ubiquitous problem of slow concept drift.
&lt;/p&gt;


&lt;div style="line-height:40%;"&gt;
    &lt;br&gt;
&lt;/div&gt;
&lt;h2&gt;Acknowledgements&lt;/h2&gt;


&lt;p&gt;
&lt;em&gt;We thank Mike Mozer for many interesting discussions in the early phase of this work, as well as very helpful advice and feedback during its development.&lt;/em&gt;
&lt;/p&gt;
</content><link href="http://blog.research.google/feeds/5605933033299261025/comments/default" rel="replies" title="Post Comments" type="application/atom+xml"/><link href="http://blog.research.google/2024/02/learning-importance-of-training-data.html#comment-form" rel="replies" title="0 Comments" type="text/html"/><link href="http://www.blogger.com/feeds/8474926331452026626/posts/default/5605933033299261025" rel="edit" type="application/atom+xml"/><link href="http://www.blogger.com/feeds/8474926331452026626/posts/default/5605933033299261025" rel="self" type="application/atom+xml"/><link href="http://blog.research.google/2024/02/learning-importance-of-training-data.html" rel="alternate" title="Learning the importance of training data under concept drift" type="text/html"/><author><name>Google AI</name><uri>http://www.blogger.com/profile/12098626514775266161</uri><email>noreply@blogger.com</email><gd:image height="16" rel="http://schemas.google.com/g/2005#thumbnail" src="https://img1.blogblog.com/img/b16-rounded.gif" width="16"/></author><media:thumbnail xmlns:media="http://search.yahoo.com/mrss/" height="72" url="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgUeskw4YD6cFTpLaRnv7OwMsljyeipfAb1riYxIuBsiWd6TBmUXMJ4QoI9tlvUzWX9NzBbEjz3-P2Zl2kuXe5BrVclmqQFrLButoya5phiEELq1azrhsIaGaCz-ov_jXaMsFrGRDE0EjotyRQPOX3xV5MAkVJfKp9xecX4t2CoLBiZ8r2RpZ25Y5KRitFG/s72-c/temporalreweightinghero.png" width="72"/><thr:total>0</thr:total></entry><entry><id>tag:blogger.com,1999:blog-8474926331452026626.post-5933365460125094774</id><published>2024-02-13T14:11:00.000-08:00</published><updated>2024-02-13T14:11:49.258-08:00</updated><category scheme="http://www.blogger.com/atom/ns#" term="Differential Privacy"/><category scheme="http://www.blogger.com/atom/ns#" term="Responsible AI"/><category scheme="http://www.blogger.com/atom/ns#" term="Security and Privacy"/><title type="text">DP-Auditorium: A flexible library for auditing differential privacy</title><content 
type="html">&lt;span class="byline-author"&gt;Posted by Mónica Ribero Díaz, Research Scientist, Google Research&lt;/span&gt;


&lt;img src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhNVpxjk-jj1rIYQ8AM3A-Syqxd3d8L8-wIy8NWwyobCXmTRK7mY9h94aJYgFCiC0gnehVFFoM8-in8HsOZjfhoNce03nbsrN5fxY07wADV6ULPC0POGmCc-8eL3OqA9KrDyzQxN38JKvh6xCmLV6FZ1g0UfaXtKORhtTy0WuJexlPqV6P2c9rPdg_W_5zP/s320/hero.jpg" style="display: none;" /&gt;

&lt;p&gt;
&lt;a href="https://en.wikipedia.org/wiki/Differential_privacy"&gt;Differential privacy&lt;/a&gt; (DP) is a property of randomized mechanisms that limit the influence of any individual user’s information while processing and analyzing data. DP offers a robust solution to address growing concerns about data protection, enabling technologies &lt;a href="https://blog.research.google/2022/02/federated-learning-with-formal.html"&gt;across&lt;/a&gt; &lt;a href="https://www.apple.com/privacy/docs/Differential_Privacy_Overview.pdf"&gt;industries&lt;/a&gt; and government applications (e.g., &lt;a href="https://www.census.gov/programs-surveys/decennial-census/decade/2020/planning-management/process/disclosure-avoidance/differential-privacy.html"&gt;the US census&lt;/a&gt;) without compromising individual user identities.  As its adoption increases, it’s important to identify the potential risks of developing mechanisms with faulty implementations. Researchers have recently found errors in the mathematical proofs of private mechanisms, and their implementations. For example, &lt;a href="https://arxiv.org/pdf/1603.01699.pdf"&gt;researchers compared&lt;/a&gt; six sparse vector technique (SVT) variations and found that only two of the six actually met the asserted privacy guarantee. Even when mathematical proofs are correct, the code implementing the mechanism is vulnerable to human error.
&lt;/p&gt;
&lt;a name='more'&gt;&lt;/a&gt;
&lt;p&gt;
However, practical and efficient DP auditing is challenging primarily due to the inherent randomness of the mechanisms and the probabilistic nature of the tested guarantees. In addition, a range of guarantee types exist (e.g., &lt;a href="https://dl.acm.org/doi/10.1007/11681878_14"&gt;pure DP&lt;/a&gt;, &lt;a href="https://link.springer.com/chapter/10.1007/11761679_29"&gt;approximate DP&lt;/a&gt;, &lt;a href="https://arxiv.org/abs/1702.07476"&gt;Rényi DP&lt;/a&gt;, and &lt;a href="https://arxiv.org/pdf/1603.01887.pdf"&gt;concentrated DP&lt;/a&gt;), and this diversity contributes to the complexity of formulating the auditing problem. Further, debugging mathematical proofs and code bases is an intractable task given the volume of proposed mechanisms. While &lt;em&gt;ad hoc&lt;/em&gt; testing techniques exist under specific assumptions of mechanisms, few efforts have been made to develop an extensible tool for testing DP mechanisms. 
&lt;/p&gt;

&lt;p&gt;
To that end, in “&lt;a href="https://arxiv.org/abs/2307.05608"&gt;DP-Auditorium: A Large Scale Library for Auditing Differential Privacy&lt;/a&gt;”, we introduce an &lt;a href="https://github.com/google/differential-privacy/tree/main/python/dp_auditorium"&gt;open source library&lt;/a&gt; for auditing DP guarantees with only black-box access to a mechanism (i.e., without any knowledge of the mechanism’s internal properties). DP-Auditorium is implemented in Python and provides a flexible interface that allows contributions to continuously improve its testing capabilities. We also introduce new testing algorithms that perform divergence optimization over function spaces for Rényi DP, pure DP, and approximate DP. We demonstrate that DP-Auditorium can efficiently identify DP guarantee violations, and suggest which tests are most suitable for detecting particular bugs under various privacy guarantees.
&lt;/p&gt;


&lt;div style="line-height: 40%;"&gt;
    &lt;br /&gt;
&lt;/div&gt;
&lt;h2&gt;DP guarantees&lt;/h2&gt;


&lt;p&gt;
The output of a DP mechanism is a sample drawn from a probability distribution (&lt;em&gt;M&lt;/em&gt;(&lt;em&gt;D&lt;/em&gt;)) that satisfies a mathematical property ensuring the privacy of user data. A DP guarantee is thus tightly related to properties of pairs of probability distributions. A mechanism is differentially private if the probability distributions determined by &lt;i&gt;M&lt;/i&gt; on dataset &lt;em&gt;D&lt;/em&gt; and a neighboring dataset &lt;em&gt;D’&lt;/em&gt;, which differ by only one record, are &lt;em&gt;&lt;a href="https://en.wikipedia.org/wiki/Computational_indistinguishability"&gt;indistinguishable&lt;/a&gt;&lt;/em&gt; under a given divergence metric. 
&lt;/p&gt;

&lt;p&gt;
For example, the classical &lt;a href="https://software.imdea.org/~federico/pubs/2013.ICALP.pdf"&gt;approximate DP&lt;/a&gt; definition states that a mechanism is approximately DP with parameters (&lt;em&gt;ε&lt;/em&gt;, &lt;em&gt;δ&lt;/em&gt;) if the &lt;a href="https://arxiv.org/pdf/1508.00335.pdf"&gt;hockey-stick divergence&lt;/a&gt; of order &lt;em&gt;e&lt;sup&gt;ε&lt;/sup&gt;&lt;/em&gt; between &lt;em&gt;M&lt;/em&gt;(&lt;em&gt;D&lt;/em&gt;) and &lt;em&gt;M&lt;/em&gt;(&lt;em&gt;D’&lt;/em&gt;) is at most &lt;em&gt;δ&lt;/em&gt;. Pure DP is a special instance of approximate DP where &lt;em&gt;δ = 0&lt;/em&gt;. Finally, a mechanism is considered &lt;a href="https://arxiv.org/abs/1702.07476"&gt;Rényi DP&lt;/a&gt; with parameters (&lt;em&gt;&#120572;&lt;/em&gt;, &lt;em&gt;ε&lt;/em&gt;) if the &lt;a href="https://en.wikipedia.org/wiki/R%C3%A9nyi_entropy"&gt;Rényi divergence&lt;/a&gt; of order &lt;em&gt;&#120572;&lt;/em&gt; between &lt;em&gt;M&lt;/em&gt;(&lt;em&gt;D&lt;/em&gt;) and &lt;em&gt;M&lt;/em&gt;(&lt;em&gt;D’&lt;/em&gt;) is at most &lt;em&gt;ε&lt;/em&gt; (where &lt;em&gt;ε&lt;/em&gt; is a small positive value). In these three definitions, &lt;em&gt;ε&lt;/em&gt; does not have exactly the same meaning, but it intuitively conveys the same concept: larger values of &lt;em&gt;ε&lt;/em&gt; imply larger divergences between the two distributions, and thus less privacy, since the two distributions are easier to distinguish.
&lt;/p&gt;

&lt;div style="line-height: 40%;"&gt;
    &lt;br /&gt;
&lt;/div&gt;
&lt;h2&gt;DP-Auditorium&lt;/h2&gt;


&lt;p&gt;
DP-Auditorium comprises two main components: property testers and dataset finders. Property testers take samples from a mechanism evaluated on specific datasets as input and aim to identify privacy guarantee violations in the provided datasets. Dataset finders suggest datasets where the privacy guarantee may fail. By combining both components, DP-Auditorium enables (1) automated testing of diverse mechanisms and privacy definitions and (2) detection of bugs in privacy-preserving mechanisms. We implement various private and non-private mechanisms, including simple mechanisms that compute the mean of records and more complex mechanisms, such as different SVT and &lt;a href="https://en.wikipedia.org/wiki/Stochastic_gradient_descent"&gt;gradient descent&lt;/a&gt; mechanism variants. 
&lt;/p&gt;

&lt;p&gt;
&lt;strong&gt;Property testers&lt;/strong&gt; determine if evidence exists to reject the hypothesis that a given divergence between two probability distributions, &lt;em&gt;P&lt;/em&gt; and &lt;em&gt;Q&lt;/em&gt;, is bounded by a prespecified budget determined by the DP guarantee being tested. They compute a lower bound on the divergence from samples from &lt;em&gt;P&lt;/em&gt; and &lt;em&gt;Q&lt;/em&gt;, rejecting the property if this lower bound exceeds the budget. If the lower bound does not exceed the budget, the test provides no guarantee that the divergence is indeed bounded. To test for a range of privacy guarantees, DP-Auditorium introduces three novel testers: (1) HockeyStickPropertyTester, (2) RényiPropertyTester, and (3) MMDPropertyTester. Unlike other approaches, these testers don’t depend on explicit histogram approximations of the tested distributions. They rely on variational representations of the hockey-stick divergence, Rényi divergence, and &lt;a href="https://jmlr.csail.mit.edu/papers/v13/gretton12a.html"&gt;maximum mean discrepancy&lt;/a&gt; (MMD) that enable the estimation of divergences through optimization over function spaces. As a baseline, we implement &lt;a href="https://arxiv.org/abs/1806.06427"&gt;HistogramPropertyTester&lt;/a&gt;, a commonly used approximate DP tester. While our three testers follow a similar approach, for brevity, we focus on the HockeyStickPropertyTester in this post.
&lt;/p&gt;

&lt;p&gt;
Given two neighboring datasets, &lt;em&gt;D&lt;/em&gt; and &lt;em&gt;D’&lt;/em&gt;, the HockeyStickPropertyTester finds a lower bound, &lt;em&gt;δ̂&lt;/em&gt;, for the hockey-stick divergence between &lt;em&gt;M&lt;/em&gt;(&lt;em&gt;D&lt;/em&gt;) and &lt;em&gt;M&lt;/em&gt;(&lt;em&gt;D’&lt;/em&gt;) that holds with high probability. An approximate DP guarantee is equivalent to an upper bound on the hockey-stick divergence between the two distributions &lt;em&gt;M&lt;/em&gt;(&lt;em&gt;D&lt;/em&gt;) and &lt;em&gt;M&lt;/em&gt;(&lt;em&gt;D’&lt;/em&gt;) for every pair of neighboring datasets. Therefore, if a privacy guarantee claims that the hockey-stick divergence is at most &lt;em&gt;δ&lt;/em&gt;, and &lt;em&gt;δ̂&lt;/em&gt; &amp;gt; &lt;em&gt;δ&lt;/em&gt;, then with high probability the divergence is higher than what was promised on &lt;em&gt;D&lt;/em&gt; and &lt;em&gt;D’&lt;/em&gt;, and the mechanism cannot satisfy the given approximate DP guarantee. The lower bound &lt;em&gt;δ̂&lt;/em&gt; is computed as an empirical and tractable counterpart of a variational formulation of the hockey-stick divergence (see &lt;a href="https://arxiv.org/pdf/2307.05608.pdf"&gt;the paper&lt;/a&gt; for more details). The accuracy of &lt;em&gt;δ̂&lt;/em&gt; increases with the number of samples drawn from the mechanism, but decreases as the variational formulation is simplified. We balance these factors in order to ensure that &lt;em&gt;δ̂&lt;/em&gt; is both accurate and easy to compute. 
&lt;/p&gt;

&lt;p&gt;
&lt;strong&gt;Dataset finders&lt;/strong&gt; use &lt;a href="https://arxiv.org/pdf/2207.13676.pdf"&gt;black-box optimization&lt;/a&gt; to find datasets &lt;em&gt;D&lt;/em&gt; and &lt;em&gt;D’&lt;/em&gt; that maximize &lt;em&gt;δ̂&lt;/em&gt;, the lower bound on the divergence, making it more likely to exceed the claimed value &lt;em&gt;δ&lt;/em&gt;. Note that black-box optimization techniques are specifically designed for settings where deriving gradients for an objective function may be impractical or even impossible. These optimization techniques alternate between exploration and exploitation phases to estimate the shape of the objective function and predict areas where the objective can have optimal values. In contrast, a full exploration algorithm, such as the &lt;a href="https://en.wikipedia.org/wiki/Hyperparameter_optimization#Grid_search"&gt;grid search method&lt;/a&gt;, searches over the full space of neighboring datasets &lt;em&gt;D&lt;/em&gt; and &lt;em&gt;D’&lt;/em&gt;. DP-Auditorium implements different dataset finders through the open-source black-box optimization library &lt;a href="https://github.com/google/vizier"&gt;Vizier&lt;/a&gt;. 
&lt;/p&gt;
&lt;p&gt;
Running existing components on a new mechanism only requires defining the mechanism as a Python function that takes an array of data &lt;em&gt;D&lt;/em&gt; and a desired number of samples &lt;em&gt;n&lt;/em&gt; to be output by the mechanism computed on &lt;em&gt;D&lt;/em&gt;. In addition, we provide flexible wrappers for testers and dataset finders that allow practitioners to implement their own testing and dataset search algorithms.
&lt;/p&gt;

&lt;div style="line-height: 40%;"&gt;
    &lt;br /&gt;
&lt;/div&gt;
&lt;h2&gt;Key results&lt;/h2&gt;


&lt;p&gt;
We assess the effectiveness of DP-Auditorium on five private and nine non-private mechanisms with diverse output spaces. For each property tester, we repeat the test ten times on fixed datasets using different values of &lt;em&gt;ε&lt;/em&gt;, and report the number of times each tester identifies privacy bugs. While no tester consistently outperforms the others, we identify bugs that would be missed by previous techniques (HistogramPropertyTester). Note that the HistogramPropertyTester is not applicable to SVT mechanisms. 
&lt;/p&gt;

&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhlLYAUJ1cew8xCQNyNMvggKZ2c2bd5uHLzUdLx3xVdn_TW4ZBwd5tCI6zVVvVjmOWKJanJ4vP4swXOzNpZ4388x-iwISjqAzxnDAgM8F4-HL5gHLAGs3AIuqhns-gNJfA_AT9lmAMvItLRDEP5OjHPRFRA6OldJrY6Yost66LZ8Zsif8wIw6Uhkfa4PkN7/s785/image22.png" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="409" data-original-width="785" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhlLYAUJ1cew8xCQNyNMvggKZ2c2bd5uHLzUdLx3xVdn_TW4ZBwd5tCI6zVVvVjmOWKJanJ4vP4swXOzNpZ4388x-iwISjqAzxnDAgM8F4-HL5gHLAGs3AIuqhns-gNJfA_AT9lmAMvItLRDEP5OjHPRFRA6OldJrY6Yost66LZ8Zsif8wIw6Uhkfa4PkN7/s16000/image22.png" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;Number of times each property tester finds the privacy violation for the tested non-private mechanisms. NonDPLaplaceMean and NonDPGaussianMean mechanisms are faulty implementations of the &lt;a href="https://en.wikipedia.org/wiki/Additive_noise_differential_privacy_mechanisms#Laplace_Mechanism"&gt;Laplace&lt;/a&gt; and &lt;a href="https://en.wikipedia.org/wiki/Additive_noise_differential_privacy_mechanisms#Gaussian_Mechanism"&gt;Gaussian&lt;/a&gt; mechanisms for computing the mean.&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;

&lt;br /&gt;


&lt;p&gt;
We also analyze the implementation of a &lt;a href="https://github.com/tensorflow/privacy/blob/master/tensorflow_privacy/privacy/optimizers/dp_optimizer_keras.py"&gt;DP gradient descent algorithm&lt;/a&gt; (DP-GD) in TensorFlow that computes gradients of the loss function on private data. To preserve privacy, DP-GD employs a clipping mechanism to bound the &lt;a href="https://mathworld.wolfram.com/L2-Norm.html"&gt;l2-norm&lt;/a&gt; of the gradients by a value &lt;em&gt;G&lt;/em&gt;, followed by the addition of Gaussian noise. This implementation incorrectly assumes that the noise added has a scale of &lt;em&gt;G&lt;/em&gt;, while in reality, the scale is &lt;em&gt;sG&lt;/em&gt;, where &lt;em&gt;s&lt;/em&gt; is a positive scalar. This discrepancy leads to an approximate DP guarantee that holds only for values of &lt;em&gt;s&lt;/em&gt; greater than or equal to 1.
&lt;/p&gt;

&lt;p&gt;
We evaluate the effectiveness of property testers in detecting this bug and show that HockeyStickPropertyTester and RényiPropertyTester exhibit superior performance in identifying privacy violations, outperforming MMDPropertyTester and HistogramPropertyTester. Notably, these testers detect the bug even for values of &lt;em&gt;s&lt;/em&gt; as high as 0.6. It is worth highlighting that &lt;em&gt;s&lt;/em&gt; = 0.5 corresponds to a &lt;a href="https://github.com/tensorflow/privacy/blob/308cbda4db6ccad5d1e7d56248727274e4c0c79e/tensorflow_privacy/privacy/analysis/compute_dp_sgd_privacy_lib.py#L445C1-L446C1"&gt;common error&lt;/a&gt; in the literature that involves missing a factor of two when accounting for the privacy budget &lt;em&gt;ε&lt;/em&gt;. DP-Auditorium successfully captures this bug as shown below. For more details, see section 5.6 of &lt;a href="https://arxiv.org/pdf/2303.00654.pdf"&gt;this paper&lt;/a&gt;.
&lt;/p&gt;


&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEg-pnMcLqTWv1vSIZWncvObk3acW_SkBS3Lp_KuspJPbGBSjlepwW0hTLkCgLA7yTgU35y-Kj4HC_ddRX1fXS6T_HoF5Na87cSIcdiTBAwHnQ1sQZV3pdir_SI5PuwT7HAMEYmQohCd7wI84bNjKSt4sUVdnk9dOAXtkxCUDgzd3KZs5r2G2Z4jIZR0-FJH/s836/image21.jpg" imageanchor="1" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="332" data-original-width="836" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEg-pnMcLqTWv1vSIZWncvObk3acW_SkBS3Lp_KuspJPbGBSjlepwW0hTLkCgLA7yTgU35y-Kj4HC_ddRX1fXS6T_HoF5Na87cSIcdiTBAwHnQ1sQZV3pdir_SI5PuwT7HAMEYmQohCd7wI84bNjKSt4sUVdnk9dOAXtkxCUDgzd3KZs5r2G2Z4jIZR0-FJH/s16000/image21.jpg" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;Estimated divergences and test thresholds for different values of &lt;em&gt;s&lt;/em&gt; when testing DP-GD with the HistogramPropertyTester (&lt;strong&gt;left&lt;/strong&gt;) and the HockeyStickPropertyTester (&lt;strong&gt;right&lt;/strong&gt;).&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;



&lt;br /&gt;

&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEibbce0TFnWcnJ4CoXPVVyuZrja_3JJTnBjsza7Ig-NibA14jHoh4TIuIhLRn9BgCdo_N4hSuft7Zpl3WgNjmteMUGkQ5xdjeFH2SzZlKmPR_PvXS-JeOIcwJO8J_h7SlR9_tknZ0fLbP2qOypalwVm-nZO118Oa67zgdi_VGc72tAzGKaYpGoWIl6p_ljD/s828/image20.jpg" imageanchor="1" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="333" data-original-width="828" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEibbce0TFnWcnJ4CoXPVVyuZrja_3JJTnBjsza7Ig-NibA14jHoh4TIuIhLRn9BgCdo_N4hSuft7Zpl3WgNjmteMUGkQ5xdjeFH2SzZlKmPR_PvXS-JeOIcwJO8J_h7SlR9_tknZ0fLbP2qOypalwVm-nZO118Oa67zgdi_VGc72tAzGKaYpGoWIl6p_ljD/s16000/image20.jpg" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;Estimated divergences and test thresholds for different values of &lt;em&gt;s&lt;/em&gt; when testing DP-GD with the RényiPropertyTester (&lt;strong&gt;left&lt;/strong&gt;) and the MMDPropertyTester (&lt;strong&gt;right&lt;/strong&gt;)&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;
&lt;br /&gt;


&lt;p&gt;
To test dataset finders, we compute the number of datasets explored before finding a privacy violation. On average, the majority of bugs are discovered in fewer than 10 calls to dataset finders. Randomized and exploration/exploitation methods are more efficient at finding datasets than grid search. For more details, see the &lt;a href="https://arxiv.org/abs/2307.05608"&gt;paper&lt;/a&gt;.
&lt;/p&gt;


&lt;div style="line-height: 40%;"&gt;
    &lt;br /&gt;
&lt;/div&gt;
&lt;h2&gt;Conclusion&lt;/h2&gt;


&lt;p&gt;
DP is one of the most powerful frameworks for data protection. However, proper implementation of DP mechanisms can be challenging and prone to errors that cannot be easily detected using traditional unit testing methods. A unified testing framework can help auditors, regulators, and academics ensure that private mechanisms are indeed private. 
&lt;/p&gt;

&lt;p&gt;
DP-Auditorium is a new approach to testing DP via divergence optimization over function spaces. Our results show that this type of function-based estimation consistently outperforms previous black-box access testers. Finally, we demonstrate that these function-based estimators allow for a better discovery rate of privacy bugs compared to histogram estimation. By &lt;a href="https://github.com/google/differential-privacy/tree/main/python/dp_auditorium"&gt;open sourcing&lt;/a&gt; DP-Auditorium, we aim to establish a standard for end-to-end testing of new differentially private algorithms.
&lt;/p&gt;

&lt;div style="line-height: 40%;"&gt;
    &lt;br /&gt;
&lt;/div&gt;
&lt;h2&gt;Acknowledgements&lt;/h2&gt;


&lt;p&gt;
&lt;em&gt;The work described here was done jointly with Andrés Muñoz Medina, William Kong and Umar Syed. We thank Chris Dibak and Vadym Doroshenko for helpful engineering support and interface suggestions for our library.&lt;/em&gt;
&lt;/p&gt;</content><link href="http://blog.research.google/feeds/5933365460125094774/comments/default" rel="replies" title="Post Comments" type="application/atom+xml"/><link href="http://blog.research.google/2024/02/dp-auditorium-flexible-library-for.html#comment-form" rel="replies" title="0 Comments" type="text/html"/><link href="http://www.blogger.com/feeds/8474926331452026626/posts/default/5933365460125094774" rel="edit" type="application/atom+xml"/><link href="http://www.blogger.com/feeds/8474926331452026626/posts/default/5933365460125094774" rel="self" type="application/atom+xml"/><link href="http://blog.research.google/2024/02/dp-auditorium-flexible-library-for.html" rel="alternate" title="DP-Auditorium: A flexible library for auditing differential privacy" type="text/html"/><author><name>Google AI</name><uri>http://www.blogger.com/profile/12098626514775266161</uri><email>noreply@blogger.com</email><gd:image height="16" rel="http://schemas.google.com/g/2005#thumbnail" src="https://img1.blogblog.com/img/b16-rounded.gif" width="16"/></author><media:thumbnail xmlns:media="http://search.yahoo.com/mrss/" height="72" url="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhNVpxjk-jj1rIYQ8AM3A-Syqxd3d8L8-wIy8NWwyobCXmTRK7mY9h94aJYgFCiC0gnehVFFoM8-in8HsOZjfhoNce03nbsrN5fxY07wADV6ULPC0POGmCc-8eL3OqA9KrDyzQxN38JKvh6xCmLV6FZ1g0UfaXtKORhtTy0WuJexlPqV6P2c9rPdg_W_5zP/s72-c/hero.jpg" width="72"/><thr:total>0</thr:total></entry><entry><id>tag:blogger.com,1999:blog-8474926331452026626.post-3264694155710647310</id><published>2024-02-06T11:17:00.000-08:00</published><updated>2024-02-06T11:17:53.968-08:00</updated><category scheme="http://www.blogger.com/atom/ns#" term="Graph Mining"/><category scheme="http://www.blogger.com/atom/ns#" term="Graphs"/><category scheme="http://www.blogger.com/atom/ns#" term="Machine Learning"/><category scheme="http://www.blogger.com/atom/ns#" term="TensorFlow"/><title type="text">Graph neural networks in TensorFlow</title><content 
type="html">&lt;span class="byline-author"&gt;Posted by Dustin Zelle, Software Engineer, Google Research, and Arno Eigenwillig, Software Engineer, CoreML&lt;/span&gt;

&lt;img src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhcnTwrjg8cyZhVY1c-qi2ZEenIrDlkmlKlX0GsAuiKiIoxUu6i-phANh8tsCG4mUm5i-7t3zdLwuwn5DCcuQI5FKq-C3eibPnuqfoLuKFUsx-I3Ovim1Teps_JKiKZH7XqgHupnsOa2Y3peUgWcPNYG4ZIqA2_KQwxJpflo0WM6gNW8tXg5eDndiWx_dKK/s1600/TFGNN%20hero.gif" style="display: none;" /&gt;

&lt;p&gt;
Objects and their relationships are ubiquitous in the world around us, and relationships can be as important to understanding an object as its own attributes viewed in isolation — take for example transportation networks, production networks, knowledge graphs, or social networks. Discrete mathematics and computer science have a long history of formalizing such networks as &lt;em&gt;&lt;a href="https://en.wikipedia.org/wiki/Graph_(discrete_mathematics)"&gt;graphs&lt;/a&gt;&lt;/em&gt;, consisting of &lt;em&gt;nodes&lt;/em&gt; connected by &lt;em&gt;edges&lt;/em&gt; in various irregular ways. Yet most machine learning (ML) algorithms allow only for regular and uniform relations between input objects, such as a grid of pixels, a sequence of words, or no relation at all. 
&lt;/p&gt;
&lt;a name='more'&gt;&lt;/a&gt;

&lt;p&gt;
&lt;a href="https://distill.pub/2021/gnn-intro/"&gt;Graph neural networks&lt;/a&gt;, or GNNs for short, have emerged as a powerful technique to leverage both the graph’s connectivity (as in the older algorithms &lt;a href="http://perozzi.net/projects/deepwalk/"&gt;DeepWalk&lt;/a&gt; and &lt;a href="https://snap.stanford.edu/node2vec/"&gt;Node2Vec&lt;/a&gt;) and the input features on the various nodes and edges. GNNs can make predictions for graphs as a whole (Does this molecule react in a certain way?), for individual nodes (What’s the topic of this document, given its citations?) or for potential edges (Is this product likely to be purchased together with that product?). Apart from making predictions about graphs, GNNs are a powerful tool used to bridge the chasm to more typical neural network use cases. They encode a graph's &lt;em&gt;discrete&lt;/em&gt;, &lt;em&gt;relational&lt;/em&gt; information in a &lt;em&gt;continuous&lt;/em&gt; way so that it can be included naturally in another deep learning system.
&lt;/p&gt;
&lt;p&gt;
We are excited to announce the release of &lt;a href="https://github.com/tensorflow/gnn"&gt;TensorFlow GNN 1.0&lt;/a&gt; (TF-GNN), a production-tested library for building GNNs at large scales. It supports both modeling and training in TensorFlow as well as the extraction of input graphs from huge data stores. TF-GNN is built from the ground up for heterogeneous graphs, where types of objects and relations are represented by distinct sets of nodes and edges. Real-world objects and their relations occur in distinct types, and TF-GNN's heterogeneous focus makes it natural to represent them.
&lt;/p&gt;
&lt;p&gt;
  Inside TensorFlow, such graphs are represented by objects of type &lt;code&gt;tfgnn.GraphTensor&lt;/code&gt;. This is a composite tensor type (a collection of tensors in one Python class) accepted as a &lt;a href="https://en.wikipedia.org/wiki/First-class_citizen"&gt;first-class citizen&lt;/a&gt; in &lt;code&gt;tf.data.Dataset&lt;/code&gt;, &lt;code&gt;tf.function&lt;/code&gt;, etc. It stores both the graph structure and its features attached to nodes, edges and the graph as a whole. Trainable transformations of GraphTensors can be defined as Layer objects in the high-level &lt;a href="https://www.tensorflow.org/guide/keras"&gt;Keras API&lt;/a&gt;, or directly using the &lt;code&gt;tfgnn.GraphTensor&lt;/code&gt; primitive.
&lt;/p&gt;



&lt;div style="line-height: 40%;"&gt;
    &lt;br /&gt;
&lt;/div&gt;
&lt;h2&gt;GNNs: Making predictions for an object in context&lt;/h2&gt;


&lt;p&gt;
For illustration, let’s look at one typical application of TF-GNN: predicting a property of a certain type of node in a graph defined by cross-referencing tables of a huge database. For example, consider a citation database of Computer Science (CS) arXiv papers with one-to-many cites and many-to-one cited relationships, where we would like to predict the subject area of each paper.
&lt;/p&gt;
&lt;p&gt;
Like most neural networks, a GNN is trained on a dataset of many labeled examples (~millions), but each training step consists only of a much smaller batch of training examples (say, hundreds). To scale to millions, the GNN gets trained on a stream of reasonably small subgraphs from the underlying graph. Each subgraph contains enough of the original data to compute the GNN result for the labeled node at its center and train the model. This process — typically referred to as subgraph sampling — is extremely consequential for GNN training. Most existing tooling accomplishes sampling in batch mode, producing static subgraphs for training. TF-GNN provides tooling to improve on this by sampling dynamically and interactively. 
&lt;/p&gt;




&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhE36FVnslwVrX4LjLgpe5NOcVgJ2WSHCaw64LT9pMhjhHOFt-1pjp1AhaXqjxfEODX04Buw93D1G36HOStu5_mWUEdNs0gZTa1c7MXJ6ir9DYOp_HCYpFMT5NZiBbHxNwvUmF-dwhN2rgKQX0CeFY25X9aFnoD0W7bzL_xtkDJFdP0guocAJDSOgBHIiZm/s800/image2.gif" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="600" data-original-width="800" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhE36FVnslwVrX4LjLgpe5NOcVgJ2WSHCaw64LT9pMhjhHOFt-1pjp1AhaXqjxfEODX04Buw93D1G36HOStu5_mWUEdNs0gZTa1c7MXJ6ir9DYOp_HCYpFMT5NZiBbHxNwvUmF-dwhN2rgKQX0CeFY25X9aFnoD0W7bzL_xtkDJFdP0guocAJDSOgBHIiZm/s16000/image2.gif" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;Pictured, the process of subgraph sampling where small, tractable subgraphs are sampled from a larger graph to create input examples for GNN training.&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;




&lt;p&gt;
TF-GNN 1.0 debuts a flexible Python API to configure dynamic or batch subgraph sampling at all relevant scales: interactively in a Colab notebook (like &lt;a href="https://colab.research.google.com/github/tensorflow/gnn/blob/master/examples/notebooks/ogbn_mag_e2e.ipynb"&gt;this one&lt;/a&gt;), for efficient sampling of a small dataset stored in the main memory of a single training host, or distributed by &lt;a href="https://beam.apache.org/"&gt;Apache Beam&lt;/a&gt; for huge datasets stored on a network filesystem (up to hundreds of millions of nodes and billions of edges). For details, please refer to our user guides for &lt;a href="https://github.com/tensorflow/gnn/blob/main/tensorflow_gnn/docs/guide/inmemory_sampler.md"&gt;in-memory&lt;/a&gt; and &lt;a href="https://github.com/tensorflow/gnn/blob/main/tensorflow_gnn/docs/guide/beam_sampler.md"&gt;beam-based&lt;/a&gt; sampling, respectively.
&lt;/p&gt;
&lt;p&gt;
On those same sampled subgraphs, the GNN’s task is to compute a hidden (or latent) state at the root node; the hidden state aggregates and encodes the relevant information of the root node's neighborhood. One classical approach is &lt;a href="https://research.google/pubs/neural-message-passing-for-quantum-chemistry/"&gt;message-passing neural networks&lt;/a&gt;. In each round of message passing, nodes receive messages from their neighbors along incoming edges and update their own hidden state from them. After &lt;em&gt;n&lt;/em&gt; rounds, the hidden state of the root node reflects the aggregate information from all nodes within &lt;em&gt;n&lt;/em&gt; edges (pictured below for &lt;em&gt;n&lt;/em&gt; = 2). The messages and the new hidden states are computed by hidden layers of the neural network. In a heterogeneous graph, it often makes sense to use separately trained hidden layers for the different types of nodes and edges.
&lt;/p&gt;




&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjMrCrQ1SCcwhZfE33X46EifocYAmKCPXMVe1d4na1V6flQavJ_f_FKtnlQbe2vnvzbSEtx5mxJHZ2OlQbO9rsiEhiPLY1PKQOT-EwahobMIVC92PZJs8RroEuYswHCpEjjpwqPrpqzKsDgrNaiY4lM_E8NVnxVRsYn0PNxe3TghByKJpW9V_YRD0RnNnm4/s573/image1.gif" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="511" data-original-width="573" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjMrCrQ1SCcwhZfE33X46EifocYAmKCPXMVe1d4na1V6flQavJ_f_FKtnlQbe2vnvzbSEtx5mxJHZ2OlQbO9rsiEhiPLY1PKQOT-EwahobMIVC92PZJs8RroEuYswHCpEjjpwqPrpqzKsDgrNaiY4lM_E8NVnxVRsYn0PNxe3TghByKJpW9V_YRD0RnNnm4/s16000/image1.gif" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;Pictured, a simple message-passing neural network where, at each step, the node state is propagated from outer to inner nodes where it is pooled to compute new node states. Once the root node is reached, a final prediction can be made.&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;
&lt;p&gt;
The training setup is completed by placing an output layer on top of the GNN’s hidden state for the labeled nodes, computing the &lt;em&gt;loss &lt;/em&gt;(to measure the prediction error), and updating model weights by backpropagation, as usual in any neural network training. 
&lt;/p&gt;
&lt;p&gt;
Beyond supervised training (i.e., minimizing a loss defined by labels), GNNs can also be trained in an unsupervised way (i.e., without labels). This lets us compute a &lt;em&gt;continuous&lt;/em&gt; representation (or &lt;em&gt;embedding&lt;/em&gt;) of the &lt;em&gt;discrete&lt;/em&gt; graph structure of nodes and their features. These representations are then typically utilized in other ML systems. In this way, the discrete, relational information encoded by a graph can be included in more typical neural network use cases. TF-GNN supports a fine-grained specification of unsupervised objectives for heterogeneous graphs.
&lt;/p&gt;




&lt;div style="line-height: 40%;"&gt;
    &lt;br /&gt;
&lt;/div&gt;
&lt;h2&gt;Building GNN architectures&lt;/h2&gt;


&lt;p&gt;
The TF-GNN library supports building and training GNNs at various levels of abstraction.
&lt;/p&gt;
&lt;p&gt;
At the highest level, users can take any of the predefined models bundled with the library that are expressed in Keras layers. Besides a small collection of models from the research literature, TF-GNN comes with a highly configurable model template that provides a curated selection of modeling choices that we have found to provide strong baselines on many of our in-house problems. The templates implement GNN layers; users need only to initialize the Keras layers.
&lt;/p&gt;

&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjMfB8QoX14UU1GEAmFFOP0cAj__zxa_MKzVSiJoak9cVLNdbbhrSxbIWhqQM3OYKA5lo7zW8sWr6-9utm-rw0808rBOE4Cbw7NZxcmifenvF6DCH4opWhVQJHR-MLGcFoNu_WpET5h1PZRdXMhjcyKgBg3NchNTPq6gWVVluzcQNaO5qtonVp5KnJRgUaD/s1400/TFGNN%20code1.png" imageanchor="1" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="865" data-original-width="1400" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjMfB8QoX14UU1GEAmFFOP0cAj__zxa_MKzVSiJoak9cVLNdbbhrSxbIWhqQM3OYKA5lo7zW8sWr6-9utm-rw0808rBOE4Cbw7NZxcmifenvF6DCH4opWhVQJHR-MLGcFoNu_WpET5h1PZRdXMhjcyKgBg3NchNTPq6gWVVluzcQNaO5qtonVp5KnJRgUaD/s16000/TFGNN%20code1.png" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;







&lt;p&gt;
At the lowest level, users can write a GNN model from scratch in terms of primitives for passing data around the graph, such as broadcasting data from a node to all its outgoing edges or pooling data into a node from all its incoming edges (e.g., computing the sum of incoming messages). TF-GNN’s graph data model treats nodes, edges and whole input graphs equally when it comes to features or hidden states, making it straightforward to express not only node-centric models like the MPNN discussed above but also more general forms of &lt;a href="https://arxiv.org/abs/1806.01261"&gt;GraphNets&lt;/a&gt;. This can, but need not, be done with Keras as a modeling framework on top of core TensorFlow. For more details, and intermediate levels of modeling, see the TF-GNN &lt;a href="https://github.com/tensorflow/gnn/blob/main/tensorflow_gnn/docs/guide/gnn_modeling.md"&gt;user guide&lt;/a&gt; and &lt;a href="https://github.com/tensorflow/gnn/tree/main/tensorflow_gnn/models"&gt;model collection&lt;/a&gt;.
&lt;/p&gt;




&lt;div style="line-height: 40%;"&gt;
    &lt;br /&gt;
&lt;/div&gt;
&lt;h2&gt;Training orchestration&lt;/h2&gt;


&lt;p&gt;
While advanced users are free to do custom model training, the &lt;a href="https://github.com/tensorflow/gnn/blob/main/tensorflow_gnn/docs/guide/runner.md"&gt;TF-GNN Runner&lt;/a&gt; also provides a succinct way to orchestrate the training of Keras models in the common cases. A simple invocation may look like this:
&lt;/p&gt;

&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgxRRMrWL-AyxpHeyAhffhApAzlq-u7FoZaDnZFlwRsoYCljzZNi0LmRDDMwZ7mkXeBK0oUFujf_TDD-zlTQcgnLGhPedfrJ2vVs-D5-RPZFWXaaRpOJIt-MH3N8Tj7NZy-SFXTjxjDrhHQY_HVUA3-_C8_xQjfRWBlO-dzcFzgUL6wynMWJhUM7z_MYKvF/s1400/TFGNN%20code2.png" imageanchor="1" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="508" data-original-width="1400" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgxRRMrWL-AyxpHeyAhffhApAzlq-u7FoZaDnZFlwRsoYCljzZNi0LmRDDMwZ7mkXeBK0oUFujf_TDD-zlTQcgnLGhPedfrJ2vVs-D5-RPZFWXaaRpOJIt-MH3N8Tj7NZy-SFXTjxjDrhHQY_HVUA3-_C8_xQjfRWBlO-dzcFzgUL6wynMWJhUM7z_MYKvF/s16000/TFGNN%20code2.png" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;




&lt;p&gt;
The Runner provides ready-to-use solutions for ML pains like distributed training and &lt;code&gt;tfgnn.GraphTensor&lt;/code&gt; padding for fixed shapes on Cloud TPUs. Beyond training on a single task (as shown above), it supports joint training on multiple (two or more) tasks in concert. For example, unsupervised tasks can be mixed with supervised ones to inform a final continuous representation (or embedding) with application-specific inductive biases. Callers only need to substitute the task argument with a mapping of tasks:
&lt;/p&gt;

&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEg4GGpfZib5MAUnX7BRLywJC4xMVt9Tz8kSMhgyDGN5A-aS9k-gna_t0Fo3uxMaAb8gK0ovrOO3XkeSNZ3i24leBCNsALR2NU_MWI7M_s47p2bx-aviaUKy_DxDEkzndNYMI_52jcEmNKyJrqDFye3_PHaWJZz7MAQ1lVW-YpuWPOOYpSAfbrunU5q4M2ev/s1400/TFGNN%20code3.png" imageanchor="1" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="392" data-original-width="1400" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEg4GGpfZib5MAUnX7BRLywJC4xMVt9Tz8kSMhgyDGN5A-aS9k-gna_t0Fo3uxMaAb8gK0ovrOO3XkeSNZ3i24leBCNsALR2NU_MWI7M_s47p2bx-aviaUKy_DxDEkzndNYMI_52jcEmNKyJrqDFye3_PHaWJZz7MAQ1lVW-YpuWPOOYpSAfbrunU5q4M2ev/s16000/TFGNN%20code3.png" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;



&lt;p&gt;
Additionally, the TF-GNN Runner also includes an implementation of &lt;a href="https://www.tensorflow.org/tutorials/interpretability/integrated_gradients"&gt;integrated gradients&lt;/a&gt; for use in model attribution. The output of integrated gradients is a GraphTensor with the same connectivity as the observed GraphTensor, but with its features replaced by gradient values; features with larger values contribute more to the GNN prediction than those with smaller values. Users can inspect gradient values to see which features their GNN uses the most. 
&lt;/p&gt;




&lt;div style="line-height: 40%;"&gt;
    &lt;br /&gt;
&lt;/div&gt;
&lt;h2&gt;Conclusion&lt;/h2&gt;


&lt;p&gt;
In short, we hope TF-GNN will be useful to advance the application of GNNs in TensorFlow at scale and fuel further innovation in the field. If you’re curious to find out more, please try our &lt;a href="https://colab.sandbox.google.com/github/tensorflow/gnn/blob/master/examples/notebooks/ogbn_mag_e2e.ipynb"&gt;Colab demo&lt;/a&gt; with the popular OGBN-MAG benchmark (in your browser, no installation required), browse the rest of our &lt;a href="https://github.com/tensorflow/gnn/blob/main/tensorflow_gnn/docs/guide/overview.md"&gt;user guides and Colabs&lt;/a&gt;, or take a look at our &lt;a href="https://arxiv.org/abs/2207.03522"&gt;paper&lt;/a&gt;.
&lt;/p&gt;




&lt;div style="line-height: 40%;"&gt;
    &lt;br /&gt;
&lt;/div&gt;
&lt;h2&gt;Acknowledgements&lt;/h2&gt;


&lt;p&gt;
&lt;em&gt;The TF-GNN release 1.0 was developed by a collaboration among Google Research: Sami Abu-El-Haija, Neslihan Bulut, Bahar Fatemi, Johannes Gasteiger, Pedro Gonnet, Jonathan Halcrow, Liangze Jiang, Silvio Lattanzi, Brandon Mayer, Vahab Mirrokni, Bryan Perozzi, Anton Tsitsulin, Dustin Zelle, Google Core ML: Arno Eigenwillig, Oleksandr Ferludin, Parth Kothari, Mihir Paradkar, Jan Pfeifer, Rachael Tamakloe, and Google DeepMind: Alvaro Sanchez-Gonzalez and Lisa Wang.&lt;/em&gt;
&lt;/p&gt;
</content><link href="http://blog.research.google/feeds/3264694155710647310/comments/default" rel="replies" title="Post Comments" type="application/atom+xml"/><link href="http://blog.research.google/2024/02/graph-neural-networks-in-tensorflow.html#comment-form" rel="replies" title="0 Comments" type="text/html"/><link href="http://www.blogger.com/feeds/8474926331452026626/posts/default/3264694155710647310" rel="edit" type="application/atom+xml"/><link href="http://www.blogger.com/feeds/8474926331452026626/posts/default/3264694155710647310" rel="self" type="application/atom+xml"/><link href="http://blog.research.google/2024/02/graph-neural-networks-in-tensorflow.html" rel="alternate" title="Graph neural networks in TensorFlow" type="text/html"/><author><name>Google AI</name><uri>http://www.blogger.com/profile/12098626514775266161</uri><email>noreply@blogger.com</email><gd:image height="16" rel="http://schemas.google.com/g/2005#thumbnail" src="https://img1.blogblog.com/img/b16-rounded.gif" width="16"/></author><media:thumbnail xmlns:media="http://search.yahoo.com/mrss/" height="72" url="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhcnTwrjg8cyZhVY1c-qi2ZEenIrDlkmlKlX0GsAuiKiIoxUu6i-phANh8tsCG4mUm5i-7t3zdLwuwn5DCcuQI5FKq-C3eibPnuqfoLuKFUsx-I3Ovim1Teps_JKiKZH7XqgHupnsOa2Y3peUgWcPNYG4ZIqA2_KQwxJpflo0WM6gNW8tXg5eDndiWx_dKK/s72-c/TFGNN%20hero.gif" width="72"/><thr:total>0</thr:total></entry><entry><id>tag:blogger.com,1999:blog-8474926331452026626.post-6956767920612914706</id><published>2024-02-02T11:07:00.000-08:00</published><updated>2024-02-07T16:05:00.722-08:00</updated><category scheme="http://www.blogger.com/atom/ns#" term="Google Cloud Platform"/><category scheme="http://www.blogger.com/atom/ns#" term="Machine Learning"/><title type="text">A decoder-only foundation model for time-series forecasting</title><content type="html">&lt;span class="byline-author"&gt;Posted by Rajat Sen and Yichen Zhou, Google Research&lt;/span&gt;


&lt;img src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgjLAVI4q3e6yNyTPTCFiLZVQfFm71GOX1TosHg_Sb8M6tVSO1hyphenhyphenZccOlufnqSuXP1rVWHmqHcely6fgW1vex4JdxenniJcaJ7TOomZolUFut8RUdxnOFZDrbt0hrIHkcrK7rl6cq5-kUuWGrOYqIirPAKtnf4vMDauPX4lFAz2PQjiqzqHxMna7eja9gOF/s320/hero.jpg" style="display: none;" /&gt;

&lt;p&gt;
&lt;a href="https://en.wikipedia.org/wiki/Time_series"&gt;Time-series&lt;/a&gt; forecasting is ubiquitous in various domains, such as retail, finance, manufacturing, healthcare, and the natural sciences. In retail use cases, for example, it has been observed that &lt;a href="https://www.mckinsey.com/featured-insights/artificial-intelligence/notes-from-the-ai-frontier-applications-and-value-of-deep-learning"&gt;improving demand forecasting accuracy&lt;/a&gt; can meaningfully reduce inventory costs and increase revenue. Deep learning (DL) models have emerged as a popular approach for forecasting rich, multivariate time-series data because they have proven to perform well in a variety of settings (e.g., DL models performed well in the &lt;a href="https://www.sciencedirect.com/science/article/pii/S0169207021001874"&gt;M5 competition&lt;/a&gt;).
&lt;/p&gt;
&lt;a name='more'&gt;&lt;/a&gt;
&lt;p&gt;
At the same time, there has been rapid progress in large foundation language models used for natural language processing (NLP) tasks, such as &lt;a href="https://en.wikipedia.org/wiki/Machine_translation"&gt;translation&lt;/a&gt;, &lt;a href="https://www.analyticsvidhya.com/blog/2023/09/retrieval-augmented-generation-rag-in-ai/"&gt;retrieval-augmented generation&lt;/a&gt;, and &lt;a href="https://en.wikipedia.org/wiki/Intelligent_code_completion"&gt;code completion&lt;/a&gt;. These models are trained on massive amounts of &lt;em&gt;textual&lt;/em&gt; data derived from a variety of sources like &lt;a href="https://commoncrawl.org/"&gt;Common Crawl&lt;/a&gt; and open-source code, which allows them to identify patterns in language. This makes them very powerful &lt;a href="https://en.wikipedia.org/wiki/Zero-shot_learning"&gt;zero-shot&lt;/a&gt; tools; for instance, &lt;a href="https://blog.google/products/bard/google-bard-try-gemini-ai/"&gt;when paired with retrieval&lt;/a&gt;, they can answer questions about and summarize current events.
&lt;/p&gt;

&lt;p&gt;
Despite DL-based forecasters largely &lt;a href="https://arxiv.org/abs/1704.04110"&gt;outperforming&lt;/a&gt; traditional methods and progress being made in &lt;a href="https://cloud.google.com/blog/products/ai-machine-learning/vertex-ai-forecasting"&gt;reducing training and inference costs&lt;/a&gt;, they face challenges: most DL architectures require &lt;a href="https://cloud.google.com/blog/products/ai-machine-learning/vertex-ai-forecasting"&gt;long and involved training and validation cycles&lt;/a&gt; before a customer can test the model on a new time-series. A foundation model for time-series forecasting, in contrast, can provide decent out-of-the-box forecasts on unseen time-series data with no additional training, enabling users to focus on refining forecasts for the actual downstream task like &lt;a href="https://en.wikipedia.org/wiki/Customer_demand_planning"&gt;retail demand planning&lt;/a&gt;.
&lt;/p&gt;

&lt;p&gt;
To that end, in “&lt;a href="https://arxiv.org/pdf/2310.10688.pdf"&gt;A decoder-only foundation model for time-series forecasting&lt;/a&gt;”, we introduce TimesFM, a single forecasting model pre-trained on a large time-series corpus of 100 billion real-world time-points. Compared to the latest large language models (LLMs), TimesFM is much smaller (200M parameters), yet we show that even at such scales, its zero-shot performance on a variety of unseen datasets of different domains and temporal granularities comes close to that of the state-of-the-art supervised approaches trained explicitly on these datasets. Later this year we plan to make this model available for external customers in &lt;a href="https://cloud.google.com/vertex-ai/docs/tabular-data/forecasting/train-model#aiplatform_create_training_pipeline_tabular_forecasting_sample-python"&gt;Google Cloud Vertex AI&lt;/a&gt;.
&lt;/p&gt;

&lt;div style="line-height: 40%;"&gt;
    &lt;br /&gt;
&lt;/div&gt;
&lt;h2&gt;A decoder-only foundation model for time-series forecasting&lt;/h2&gt;


&lt;p&gt;
LLMs are usually trained in a &lt;a href="https://arxiv.org/pdf/1801.10198.pdf"&gt;decoder-only&lt;/a&gt; fashion that involves three steps. First, text is broken down into subwords called tokens. Then, the tokens are fed into stacked causal &lt;a href="https://arxiv.org/abs/1706.03762"&gt;transformer&lt;/a&gt; layers that produce an output corresponding to each input token (the output for a given token cannot attend to future tokens). Finally, the output corresponding to the &lt;em&gt;i&lt;/em&gt;-th token summarizes all the information from previous tokens and predicts the (&lt;em&gt;i&lt;/em&gt;+1)-th token. During inference, the LLM generates the output one token at a time. For example, when prompted with “What is the capital of France?”, it might generate the token “The”, then condition on “What is the capital of France? The” to generate the next token “capital” and so on until it generates the complete answer: “The capital of France is Paris”.
&lt;/p&gt;

&lt;p&gt;
A foundation model for time-series forecasting should adapt to variable context (what we observe) and horizon (what we query the model to forecast) lengths, while having enough capacity to encode all patterns from a large pretraining dataset. Similar to LLMs, we use stacked transformer layers (self-attention and &lt;a href="https://en.wikipedia.org/wiki/Feedforward_neural_network"&gt;feedforward&lt;/a&gt; layers) as the main building blocks for the TimesFM model. In the context of time-series forecasting, we treat a patch (a group of contiguous time-points) as a token, an approach popularized by a recent &lt;a href="https://arxiv.org/abs/2211.14730"&gt;long-horizon forecasting work&lt;/a&gt;. The task then is to forecast the (&lt;em&gt;i&lt;/em&gt;+1)-th patch of time-points given the &lt;em&gt;i&lt;/em&gt;-th output at the end of the stacked transformer layers. 
&lt;/p&gt;

&lt;p&gt;
However, there are several key differences from language models. Firstly, we need a &lt;a href="https://en.wikipedia.org/wiki/Multilayer_perceptron"&gt;multilayer perceptron&lt;/a&gt; block with residual connections to convert a patch of time-series into a token that can be input to the transformer layers along with &lt;a href="https://machinelearningmastery.com/a-gentle-introduction-to-positional-encoding-in-transformer-models-part-1/"&gt;positional encodings&lt;/a&gt; (PE). For that, we use a residual block similar to our prior work in &lt;a href="https://arxiv.org/abs/2304.08424"&gt;long-horizon forecasting&lt;/a&gt;. Secondly, at the other end, an output token from the stacked transformer can be used to predict a longer length of subsequent time-points than the input patch length, i.e., the output patch length can be larger than the input patch length.  
&lt;/p&gt;

&lt;p&gt;
Consider a time-series of 512 time-points being used to train a TimesFM model with input patch length 32 and output patch length 128. During training, the model is simultaneously trained to use the first 32 time-points to forecast the next 128 time-points, the first 64 time-points to forecast time-points 65 to 192, the first 96 time-points to forecast time-points 97 to 224 and so on. During inference, suppose the model is given a new time-series of length 256 and tasked with forecasting the next 256 time-points into the future. The model will first generate the future predictions for time-points 257 to 384, then condition on the initial 256-length input plus the generated output to generate time-points 385 to 512. On the other hand, if in our model the output patch length were equal to the input patch length of 32, then for the same task we would have to go through eight generation steps instead of just the two above. This increases the chance of errors accumulating and therefore, in practice, we see that a longer output patch length yields better performance for long-horizon forecasting.
&lt;/p&gt;


&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEj4G0lBOLUqlPIXJ3R68kjS984MBIKBPDBrCWtgmjVVTyQRqY6-rn3aHJjgxCbG-8csyBLsp0POILdeJ2VcsRy8lrip0k5DWsUpuL9LU1qOPXLW99mraNdd6HVU791NYqJeTyY7LjuMnOIo6RGmkxBQqqaPrSsC0dELrwy21QUs1Jgwxr8flmdNkDV2tZsT/s1084/image3.jpg" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="674" data-original-width="1084" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEj4G0lBOLUqlPIXJ3R68kjS984MBIKBPDBrCWtgmjVVTyQRqY6-rn3aHJjgxCbG-8csyBLsp0POILdeJ2VcsRy8lrip0k5DWsUpuL9LU1qOPXLW99mraNdd6HVU791NYqJeTyY7LjuMnOIo6RGmkxBQqqaPrSsC0dELrwy21QUs1Jgwxr8flmdNkDV2tZsT/s16000/image3.jpg" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;TimesFM architecture.&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;

&lt;br /&gt;

&lt;div style="line-height: 40%;"&gt;
    &lt;br /&gt;
&lt;/div&gt;
&lt;h2&gt;Pretraining data&lt;/h2&gt;


&lt;p&gt;
Just like LLMs get better with more tokens, TimesFM requires a large volume of legitimate time series data to learn and improve. We have spent a great amount of time creating and assessing our training datasets, and the following is what we have found works best:
&lt;/p&gt;




&lt;div style="margin-left: 40px;"&gt;
&lt;p&gt;

    &lt;strong&gt;Synthetic data helps with the basics.&lt;/strong&gt; Meaningful synthetic time-series data can be generated using statistical models or physical simulations. Such data captures basic temporal patterns that can teach the model the grammar of time series forecasting.
&lt;/p&gt;&lt;/div&gt;
  
&lt;div style="margin-left: 40px;"&gt;
&lt;p&gt;  
    &lt;strong&gt;Real-world data adds real-world flavor.&lt;/strong&gt; We comb through available public time series datasets and selectively put together a large corpus of 100 billion time-points. Among these datasets are &lt;a href="https://trends.google.com/trends/"&gt;Google Trends&lt;/a&gt; and &lt;a href="https://meta.wikimedia.org/wiki/Research:Page_view"&gt;Wikipedia Pageviews&lt;/a&gt;, which track what people are interested in and nicely mirror trends and patterns in many other real-world time series. This helps TimesFM understand the bigger picture and generalize better when provided with domain-specific contexts not seen during training.
&lt;/p&gt;&lt;/div&gt;

&lt;div style="line-height: 40%;"&gt;
    &lt;br /&gt;
&lt;/div&gt;
&lt;h2&gt;Zero-shot evaluation results&lt;/h2&gt;


&lt;p&gt;
We evaluate TimesFM zero-shot on data not seen during training using popular time-series benchmarks. We observe that TimesFM performs better than most statistical methods like &lt;a href="https://en.wikipedia.org/wiki/Autoregressive_integrated_moving_average"&gt;ARIMA&lt;/a&gt; and &lt;a href="https://en.wikipedia.org/wiki/Exponential_smoothing"&gt;ETS&lt;/a&gt;, and can match or outperform powerful DL models like &lt;a href="https://arxiv.org/abs/1704.04110"&gt;DeepAR&lt;/a&gt; and &lt;a href="https://arxiv.org/abs/2211.14730"&gt;PatchTST&lt;/a&gt; that have been &lt;em&gt;explicitly trained&lt;/em&gt; on the target time-series.
&lt;/p&gt;

&lt;p&gt;
We used the &lt;a href="https://huggingface.co/datasets/monash_tsf"&gt;Monash Forecasting Archive&lt;/a&gt; to evaluate TimesFM’s out-of-the-box performance. This archive contains tens of thousands of time-series from various domains like traffic, weather, and demand forecasting covering frequencies ranging from a few minutes to yearly data. Following existing literature, we inspect the &lt;a href="https://en.wikipedia.org/wiki/Mean_absolute_error"&gt;mean absolute error&lt;/a&gt; (MAE) &lt;a href="https://arxiv.org/abs/2310.07820"&gt;appropriately scaled&lt;/a&gt; so that it can be averaged across the datasets. We see that zero-shot (ZS) TimesFM is better than most supervised approaches, including recent deep learning models. We also compare TimesFM to &lt;a href="https://platform.openai.com/docs/models/gpt-3-5"&gt;GPT-3.5&lt;/a&gt; for forecasting using a specific prompting technique proposed by &lt;a href="https://arxiv.org/abs/2310.07820"&gt;llmtime(ZS)&lt;/a&gt;. We demonstrate that TimesFM performs better than llmtime(ZS) despite being orders of magnitude smaller.
&lt;/p&gt;


&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhIeNF6GcmbUvVvYpKxNSvwlm_swz6M3G7nTDl0INa2zq8AlvjTBCVuvwOw0dx48JCk4H3S0aBUcsvqj2BypV3340cblqgD6yktoLBXzpxA2fwoM4n_KU8m0TfaESjihc3nx29RYVTpO4g09RCK-rucPulH3gqEOU9jO7EZ_VbDcFnfB_RHXmdpuZO_T_-g/s1476/image2.png" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="876" data-original-width="1476" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhIeNF6GcmbUvVvYpKxNSvwlm_swz6M3G7nTDl0INa2zq8AlvjTBCVuvwOw0dx48JCk4H3S0aBUcsvqj2BypV3340cblqgD6yktoLBXzpxA2fwoM4n_KU8m0TfaESjihc3nx29RYVTpO4g09RCK-rucPulH3gqEOU9jO7EZ_VbDcFnfB_RHXmdpuZO_T_-g/s16000/image2.png" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;Scaled MAE (the lower the better) of TimesFM(ZS) against other supervised and zero-shot approaches on Monash datasets.&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;
&lt;br&gt;



&lt;p&gt;
Most of the Monash datasets are short- or medium-horizon, i.e., the prediction length is not too long. We also test TimesFM on popular benchmarks for long-horizon forecasting against a recent state-of-the-art baseline &lt;a href="https://arxiv.org/abs/2211.14730"&gt;PatchTST&lt;/a&gt; (and other long-horizon forecasting baselines). In the next figure, we plot the MAE on &lt;a href="https://paperswithcode.com/dataset/ett"&gt;ETT&lt;/a&gt; datasets for the task of predicting 96 and 192 time-points into the future. The metric has been calculated on the last test window of each dataset (as done by the &lt;a href="https://arxiv.org/abs/2310.07820"&gt;llmtime&lt;/a&gt; paper). We see that TimesFM not only surpasses the performance of llmtime(ZS) but also matches that of the supervised PatchTST model explicitly trained on the respective datasets. 
&lt;/p&gt;





&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEj0DDM32GPO6zkmnIrObEP2OA92g45b-zSMHgCf-uNoj6Ed0M0zVsN7vmFmfgXT6Sh5p-W0xI1qj6YwXcqi3T6aD5hI9ZOJqT8Sobp43FGrtSsLUkI2poHnGml7Za4BMObSd6nEKUVL8wj7nHJDFYHbWaQOXOcfxvqXUcMxUZ3WVQW8Z5sabfFsi7M85_7I/s735/image1.png" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="433" data-original-width="735" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEj0DDM32GPO6zkmnIrObEP2OA92g45b-zSMHgCf-uNoj6Ed0M0zVsN7vmFmfgXT6Sh5p-W0xI1qj6YwXcqi3T6aD5hI9ZOJqT8Sobp43FGrtSsLUkI2poHnGml7Za4BMObSd6nEKUVL8wj7nHJDFYHbWaQOXOcfxvqXUcMxUZ3WVQW8Z5sabfFsi7M85_7I/s16000/image1.png" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;Last window MAE (the lower the better) of TimesFM(ZS) against llmtime(ZS) and long-horizon forecasting baselines on ETT datasets.&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;

&lt;br /&gt;

&lt;div style="line-height: 40%;"&gt;
    &lt;br /&gt;
&lt;/div&gt;
&lt;h2&gt;Conclusion&lt;/h2&gt;


&lt;p&gt;
We train a decoder-only foundation model for time-series forecasting using a large pretraining corpus of 100B real-world time-points, the majority of which came from search interest time-series data derived from Google Trends and pageviews from Wikipedia. We show that even a relatively small 200M-parameter pretrained model that uses our TimesFM architecture displays impressive zero-shot performance on a variety of public benchmarks from different domains and granularities. 
&lt;/p&gt;

&lt;div style="line-height: 40%;"&gt;
    &lt;br /&gt;
&lt;/div&gt;
&lt;h2&gt;Acknowledgements&lt;/h2&gt;


&lt;p&gt;
&lt;em&gt;This work is the result of a collaboration between several individuals across Google Research and Google Cloud, including (in alphabetical order): Abhimanyu Das, Skander Hannachi, Weihao Kong, Ivan Kuznetsov, Mike Lawrence, Andrew Leach, Alex Martin, Rajat Sen, Yang Yang and Yichen Zhou.&lt;/em&gt;
&lt;/p&gt;</content><link href="http://blog.research.google/feeds/6956767920612914706/comments/default" rel="replies" title="Post Comments" type="application/atom+xml"/><link href="http://blog.research.google/2024/02/a-decoder-only-foundation-model-for.html#comment-form" rel="replies" title="0 Comments" type="text/html"/><link href="http://www.blogger.com/feeds/8474926331452026626/posts/default/6956767920612914706" rel="edit" type="application/atom+xml"/><link href="http://www.blogger.com/feeds/8474926331452026626/posts/default/6956767920612914706" rel="self" type="application/atom+xml"/><link href="http://blog.research.google/2024/02/a-decoder-only-foundation-model-for.html" rel="alternate" title="A decoder-only foundation model for time-series forecasting" type="text/html"/><author><name>Google AI</name><uri>http://www.blogger.com/profile/12098626514775266161</uri><email>noreply@blogger.com</email><gd:image height="16" rel="http://schemas.google.com/g/2005#thumbnail" src="https://img1.blogblog.com/img/b16-rounded.gif" width="16"/></author><media:thumbnail xmlns:media="http://search.yahoo.com/mrss/" height="72" url="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgjLAVI4q3e6yNyTPTCFiLZVQfFm71GOX1TosHg_Sb8M6tVSO1hyphenhyphenZccOlufnqSuXP1rVWHmqHcely6fgW1vex4JdxenniJcaJ7TOomZolUFut8RUdxnOFZDrbt0hrIHkcrK7rl6cq5-kUuWGrOYqIirPAKtnf4vMDauPX4lFAz2PQjiqzqHxMna7eja9gOF/s72-c/hero.jpg" width="72"/><thr:total>0</thr:total></entry><entry><id>tag:blogger.com,1999:blog-8474926331452026626.post-2254287928040727502</id><published>2024-02-02T09:49:00.000-08:00</published><updated>2024-02-02T09:49:36.211-08:00</updated><category scheme="http://www.blogger.com/atom/ns#" term="Deep Learning"/><category scheme="http://www.blogger.com/atom/ns#" term="ICML"/><category scheme="http://www.blogger.com/atom/ns#" term="ML Fairness"/><category scheme="http://www.blogger.com/atom/ns#" term="Supervised Learning"/><title type="text">Intervening on early readouts for mitigating 
spurious features and simplicity bias</title><content type="html">&lt;span class="byline-author"&gt;Posted by Rishabh Tiwari, Pre-doctoral Researcher, and Pradeep Shenoy, Research Scientist, Google Research&lt;/span&gt;

&lt;img src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgdBd5rMRA2U1nd8fetuEweTgmHncn49ASMQtPlm6dfsr5V29RwsoUR8UtK4B7oSE1eiIdW-vD-gjCUK4tGZTbsY4XdO0adL2YtAjpgbF1S3mL_Jw3f31SwLKYUtCOLJ807gdXdRmD5iVsrtc_Ii-BiqQacv89vbtRbNAIINa9PhKAF_sDAZu09FLs4599T/s1600/SiFer%20Hero.png" style="display: none;" /&gt;

&lt;p&gt;
Machine learning models in the real world are often trained on limited data that may contain unintended &lt;a href="https://en.wikipedia.org/wiki/Bias_(statistics)"&gt;statistical biases&lt;/a&gt;. For example, in the &lt;a href="https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html"&gt;CelebA&lt;/a&gt; celebrity image dataset, a disproportionate number of female celebrities have blond hair, leading to classifiers incorrectly predicting “blond” as the hair color for most female faces — here, gender is a spurious feature for predicting hair color. Such unfair biases could have significant consequences in critical applications such as &lt;a href="https://www.researchgate.net/publication/362524426_Addressing_fairness_in_artificial_intelligence_for_medical_imaging"&gt;medical diagnosis&lt;/a&gt;. 
&lt;/p&gt;
&lt;a name='more'&gt;&lt;/a&gt;


&lt;p&gt;
Surprisingly, recent work has also discovered an inherent tendency of deep networks to &lt;em&gt;amplify such statistical biases&lt;/em&gt;, through the so-called &lt;a href="https://proceedings.neurips.cc/paper/2020/file/6cfe0e6127fa25df2a0ef2ae1067d915-Paper.pdf"&gt;simplicity bias&lt;/a&gt; of deep learning. This bias is the tendency of deep networks to identify weakly predictive features early in training and to continue anchoring on these features, failing to identify more complex and potentially more accurate features. 
&lt;/p&gt;
&lt;p&gt;
With the above in mind, we propose simple and effective fixes to this dual challenge of spurious features and simplicity bias by applying &lt;em&gt;early readouts&lt;/em&gt; and &lt;em&gt;feature forgetting&lt;/em&gt;. First, in “&lt;a href="https://arxiv.org/abs/2310.18590"&gt;Using Early Readouts to Mediate Featural Bias in Distillation&lt;/a&gt;”, we show that making predictions from early layers of a deep network (referred to as “early readouts”) can automatically signal issues with the quality of the learned representations. In particular, these predictions are more often wrong, and more confidently wrong, when the network is relying on spurious features. We use this erroneous confidence to improve outcomes in &lt;a href="https://arxiv.org/pdf/1503.02531.pdf"&gt;model distillation&lt;/a&gt;, a setting where a larger “teacher” model guides the training of a smaller “student” model. Then in “&lt;a href="https://arxiv.org/abs/2301.13293"&gt;Overcoming Simplicity Bias in Deep Networks using a Feature Sieve&lt;/a&gt;”, we intervene directly on these indicator signals by making the network “forget” the problematic features and consequently look for better, more predictive features. This substantially improves the model’s ability to generalize to unseen domains compared to previous approaches. Our &lt;a href="https://ai.google/responsibility/principles"&gt;AI Principles&lt;/a&gt; and our &lt;a href="https://ai.google/responsibility/responsible-ai-practices/"&gt;Responsible AI practices&lt;/a&gt; guide how we research and develop these advanced applications and help us address the challenges posed by statistical biases.
&lt;/p&gt;




&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhzG_p8Re7HHeTp_Qg_GwjX5LcHsE-TZDmHr3azTSOLKl4f1J4xcL9vxo46zicAl6QoIKIrTJaI2Z51iFq2oICjeb6Ut4-W1W74bytv87pH3hKVJOotWWWDk0gwB-ak_YZRmtZyimw8b9lSJ1DRzh6uIpvIBN2pbIw-6MuN47rUjTK_RzLLfYXPrIjtpjRz/s1080/image3.gif" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="540" data-original-width="1080" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhzG_p8Re7HHeTp_Qg_GwjX5LcHsE-TZDmHr3azTSOLKl4f1J4xcL9vxo46zicAl6QoIKIrTJaI2Z51iFq2oICjeb6Ut4-W1W74bytv87pH3hKVJOotWWWDk0gwB-ak_YZRmtZyimw8b9lSJ1DRzh6uIpvIBN2pbIw-6MuN47rUjTK_RzLLfYXPrIjtpjRz/s16000/image3.gif" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;Animation comparing hypothetical responses from two models trained with and without the feature sieve.&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;




&lt;div style="line-height: 40%;"&gt;
    &lt;br /&gt;
&lt;/div&gt;
&lt;h2&gt;Early readouts for debiasing distillation&lt;/h2&gt;


&lt;p&gt;
We first illustrate the diagnostic value of &lt;em&gt;early readouts&lt;/em&gt; and their application in debiased distillation, i.e., making sure that the student model inherits the teacher model’s resilience to feature bias through distillation. We start with a standard distillation framework where the student is trained with a mixture of label matching (minimizing the &lt;a href="https://towardsdatascience.com/cross-entropy-loss-function-f38c4ec8643e"&gt;cross-entropy loss&lt;/a&gt; between student outputs and the ground-truth labels) and teacher matching (minimizing the &lt;a href="https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence"&gt;KL divergence&lt;/a&gt; loss between student and teacher outputs for any given input). 
&lt;/p&gt;
&lt;p&gt;
Suppose one trains a linear decoder, i.e., a small auxiliary neural network named &lt;em&gt;Aux&lt;/em&gt;, on top of an intermediate representation of the student model. We refer to the output of this linear decoder as an early readout of the network representation. Our finding is that early readouts make more errors on instances that contain spurious features, and further, the confidence on those errors is higher than the confidence associated with other errors. This suggests that confidence on errors from early readouts is a fairly strong, automated indicator of the model’s dependence on potentially spurious features.
&lt;/p&gt;





&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEixpq4OhPGxL9gGW30-0kqQ_CieDj3PJcqw8L4_7fBDZOFKuQpI67ljqIItOoJ3U9-dpPd1CpofAG_ld689r0HcPTrzFeTd1ceMQ42C3CRPWWJMYknydHpJhFjQUjb-M6mx8ILQbWEBIOv-NSgTauMGgDZ8t3EMGHE3j6UN9HIF3BJmB63GhOzFwOVmswlc/s1128/image5.png" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="796" data-original-width="1128" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEixpq4OhPGxL9gGW30-0kqQ_CieDj3PJcqw8L4_7fBDZOFKuQpI67ljqIItOoJ3U9-dpPd1CpofAG_ld689r0HcPTrzFeTd1ceMQ42C3CRPWWJMYknydHpJhFjQUjb-M6mx8ILQbWEBIOv-NSgTauMGgDZ8t3EMGHE3j6UN9HIF3BJmB63GhOzFwOVmswlc/s16000/image5.png" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;Illustrating the usage of early readouts (i.e., output from the auxiliary layer) in debiasing distillation. Instances that are confidently mispredicted in the early readouts are upweighted in the distillation loss.&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;





&lt;p&gt;
We used this signal to modulate the contribution of the teacher in the distillation loss on a per-instance basis, and found significant improvements in the trained student model as a result.
&lt;/p&gt;
&lt;p&gt;
We evaluated our approach on standard benchmark datasets known to contain spurious correlations (&lt;a href="https://arxiv.org/pdf/1911.08731.pdf"&gt;Waterbirds&lt;/a&gt;, &lt;a href="https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html"&gt;CelebA&lt;/a&gt;, &lt;a href="https://www.tensorflow.org/datasets/catalog/civil_comments"&gt;CivilComments&lt;/a&gt;, &lt;a href="https://cims.nyu.edu/~sbowman/multinli/"&gt;MNLI&lt;/a&gt;). Each of these datasets contains groupings of data that share an attribute potentially correlated with the label in a spurious manner. As an example, the CelebA dataset mentioned above includes groups such as {blond male, blond female, non-blond male, non-blond female}, with models typically performing the worst on the {non-blond female} group when predicting hair color. Thus, a key measure of model performance is its &lt;em&gt;worst group accuracy&lt;/em&gt;, i.e., the lowest accuracy among all known groups present in the dataset. We improved the worst group accuracy of student models on all datasets; moreover, we also improved overall accuracy in three of the four datasets, showing that our improvement on any one group does not come at the expense of accuracy on other groups. More details are available in our &lt;a href="https://arxiv.org/pdf/2310.18590.pdf"&gt;paper&lt;/a&gt;.
&lt;/p&gt;





&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhpQiz04rM3DMtDiusAWyWl92FMUKbafR0l2dGvrj17fX3nuvPDnyXMQaumsxDvch3ScnOCL4Duq5_O32dWbv_CTsIu5aNc-c3xrVAIXjQ3kmn0jZ_TZ5SJ7C2lq1oxLZ33-VKXSSPRa_oGUB5jJlsBTZupsHMeUtSVXLh414e1NVEgI1IamqhTA1dqU0s5/s1270/image4.png" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="1014" data-original-width="1270" height="511" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhpQiz04rM3DMtDiusAWyWl92FMUKbafR0l2dGvrj17fX3nuvPDnyXMQaumsxDvch3ScnOCL4Duq5_O32dWbv_CTsIu5aNc-c3xrVAIXjQ3kmn0jZ_TZ5SJ7C2lq1oxLZ33-VKXSSPRa_oGUB5jJlsBTZupsHMeUtSVXLh414e1NVEgI1IamqhTA1dqU0s5/w640-h511/image4.png" width="640" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;Comparison of Worst Group Accuracies of different distillation techniques relative to that of the Teacher model. Our method outperforms other methods on all datasets.&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;




&lt;div style="line-height: 40%;"&gt;
    &lt;br /&gt;
&lt;/div&gt;
&lt;h2&gt;Overcoming simplicity bias with a feature sieve&lt;/h2&gt;


&lt;p&gt;
In a second, closely related project, we intervene directly on the information provided by early readouts, to improve &lt;a href="https://en.wikipedia.org/wiki/Feature_learning"&gt;feature learning&lt;/a&gt; and &lt;a href="https://developers.google.com/machine-learning/crash-course/generalization/video-lecture"&gt;generalization&lt;/a&gt;. The workflow alternates between &lt;em&gt;identifying &lt;/em&gt;problematic features and &lt;em&gt;erasing identified features&lt;/em&gt; from the network. Our primary hypothesis is that early features are more prone to simplicity bias, and that by erasing (“sieving”) these features, we allow richer feature representations to be learned.  
&lt;/p&gt;




&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEghN4NJ5vZ6jESH3koLTfGa3DpSenk5liLEg2awv2cOo1blDwwuDjLGVGxyeHSAzkLWTBUwO_swf4uGC2oShnD0WTNrebCL9KLAMOBIxR3ZZnw9eVS8g16s_lgP5kCbhZmVoTctASyDVvb3wtzIlzju01m4ADr7G21NpOWpac55hBllzYBaQVAXCjq8BIca/s1098/image6.png" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="604" data-original-width="1098" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEghN4NJ5vZ6jESH3koLTfGa3DpSenk5liLEg2awv2cOo1blDwwuDjLGVGxyeHSAzkLWTBUwO_swf4uGC2oShnD0WTNrebCL9KLAMOBIxR3ZZnw9eVS8g16s_lgP5kCbhZmVoTctASyDVvb3wtzIlzju01m4ADr7G21NpOWpac55hBllzYBaQVAXCjq8BIca/s16000/image6.png" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;Training workflow with feature sieve. We alternate between identifying problematic features (using training iteration) and erasing them from the network (using forgetting iteration).&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;



&lt;p&gt;
We describe the identification and erasure steps in more detail: 
&lt;/p&gt;
&lt;ul&gt;

&lt;li&gt;&lt;b&gt;Identifying simple features&lt;/b&gt;:  We train the primary model and the readout model (AUX above) in the conventional fashion via forward- and back-propagation. Note that feedback from the auxiliary layer does not back-propagate to the main network. This is to force the auxiliary layer to learn from already-available features rather than create or reinforce them in the main network. 

&lt;/li&gt;&lt;li&gt;&lt;b&gt;Applying the feature sieve&lt;/b&gt;: We aim to erase the identified features in the early layers of the neural network with the use of a novel &lt;em&gt;forgetting loss&lt;/em&gt;, &lt;em&gt;L&lt;sub&gt;f&lt;/sub&gt;&lt;/em&gt;, which is simply the cross-entropy between the readout and a uniform distribution over labels. Essentially, all information that leads to nontrivial readouts is erased from the primary network. In this step, the auxiliary network and upper layers of the main network are kept unchanged.
&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;
We can control specifically how the feature sieve is applied to a given dataset through a small number of configuration parameters. By changing the position and complexity of the auxiliary network, we control the complexity of the identified and erased features. By adjusting the mix of learning and forgetting steps, we control the degree to which the model is challenged to learn more complex features. These choices, which are dataset-dependent, are made via &lt;a href="https://en.wikipedia.org/wiki/Hyperparameter_optimization"&gt;hyperparameter search&lt;/a&gt; to maximize validation accuracy, a standard measure of generalization. Since we include “no-forgetting” (i.e., the baseline model) in the search space, we expect to find settings that are at least as good as the baseline.
&lt;/p&gt;
&lt;p&gt;
Below we show features learned by the baseline model (middle row) and our model (bottom row) on two benchmark datasets — biased activity recognition (&lt;a href="https://github.com/alinlab/BAR"&gt;BAR&lt;/a&gt;) and animal categorization (&lt;a href="https://arxiv.org/pdf/1906.02899v3.pdf"&gt;NICO&lt;/a&gt;). Feature importance was estimated using post-hoc gradient-based importance scoring (&lt;a href="https://arxiv.org/abs/1610.02391"&gt;GRAD-CAM&lt;/a&gt;), with the orange-red end of the spectrum indicating high importance and the green-blue end indicating low importance. As shown below, our trained models focus on the primary object of interest, whereas the baseline model tends to focus on background features that are simpler and spuriously correlated with the label.
&lt;/p&gt;




&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgumwu2DQ-nPeTLxt_uS6q6tIR6oQZdlWOoM4_I5kUmYfyJi8xyWIpw7WusdRAsA_YthYgO2Zz8sj7V1Id3JOTsljM9zpK2vwhokMfnZQOxbAIWtaFvFN4sfN6qF0rkOklj10y-_rLfL-WQS4zf6AWCub7aUTS7a8LyEsZ5uhQmXjTai7neuWElZBbP_5UI/s1616/image2.png" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="850" data-original-width="1616" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgumwu2DQ-nPeTLxt_uS6q6tIR6oQZdlWOoM4_I5kUmYfyJi8xyWIpw7WusdRAsA_YthYgO2Zz8sj7V1Id3JOTsljM9zpK2vwhokMfnZQOxbAIWtaFvFN4sfN6qF0rkOklj10y-_rLfL-WQS4zf6AWCub7aUTS7a8LyEsZ5uhQmXjTai7neuWElZBbP_5UI/s16000/image2.png" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;Feature importance scoring using GRAD-CAM on activity recognition (BAR) and animal categorization (NICO) generalization benchmarks. Our approach (last row) focuses on the relevant objects in the image, whereas the baseline (ERM; middle row) relies on background features that are spuriously correlated with the label.&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;




&lt;p&gt;
Through this ability to learn better, more generalizable features, we show substantial gains over a range of relevant baselines on real-world spurious-feature benchmark datasets: &lt;a href="https://github.com/alinlab/BAR"&gt;BAR&lt;/a&gt;, &lt;a href="https://arxiv.org/pdf/2104.06885.pdf"&gt;CelebA Hair&lt;/a&gt;, &lt;a href="https://nico.thumedialab.com/"&gt;NICO&lt;/a&gt; and &lt;a href="https://www.tensorflow.org/datasets/catalog/imagenet_a"&gt;ImagenetA&lt;/a&gt;, by margins of up to 11% (see figure below). More details are available in &lt;a href="https://arxiv.org/abs/2301.13293"&gt;our paper&lt;/a&gt;.
&lt;/p&gt;




&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjjuXHls8mwfL2u-TVZlDlu5UMPrank9F2ODbf6h12q9oMLNrIYyfyv4OuQriS0XzI-z0BrQOs2xUiXt53lGLQtdzmKQDtGXFtv6TZEGg4pKua8JD9AkQn0J92mTjlQAlZTUPgqIYRAFpnsRTU0szE5J90_LeGNj3PTUKrsgq3WAMAjWSy30HQtMnNzevvY/s1082/image1.png" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="1082" data-original-width="844" height="640" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjjuXHls8mwfL2u-TVZlDlu5UMPrank9F2ODbf6h12q9oMLNrIYyfyv4OuQriS0XzI-z0BrQOs2xUiXt53lGLQtdzmKQDtGXFtv6TZEGg4pKua8JD9AkQn0J92mTjlQAlZTUPgqIYRAFpnsRTU0szE5J90_LeGNj3PTUKrsgq3WAMAjWSy30HQtMnNzevvY/w501-h640/image1.png" width="501" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;Our feature sieve method improves accuracy by significant margins relative to the nearest baseline for a range of feature generalization benchmark datasets.&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;




&lt;div style="line-height: 40%;"&gt;
    &lt;br /&gt;
&lt;/div&gt;
&lt;h2&gt;Conclusion&lt;/h2&gt;


&lt;p&gt;
We hope that our work on early readouts and their use in feature sieving for generalization will both spur the development of a new class of adversarial feature learning approaches and help improve the generalization capability and robustness of deep learning systems. 
&lt;/p&gt;


&lt;div style="line-height: 40%;"&gt;
    &lt;br /&gt;
&lt;/div&gt;
&lt;h2&gt;Acknowledgements &lt;/h2&gt;


&lt;p&gt;
&lt;em&gt;The work on applying early readouts to debiasing distillation was conducted in collaboration with our academic partners Durga Sivasubramanian, Anmol Reddy and Prof. Ganesh Ramakrishnan at &lt;a href="https://www.iitb.ac.in/"&gt;IIT Bombay&lt;/a&gt;. We extend our sincere gratitude to Praneeth Netrapalli and Anshul Nasery for their feedback and recommendations. We are also grateful to Nishant Jain, Shreyas Havaldar, Rachit Bansal, Kartikeya Badola, Amandeep Kaur and the whole cohort of pre-doctoral researchers at Google Research India for taking part in research discussions. Special thanks to Tom Small for creating the animation used in this post.&lt;/em&gt;
&lt;/p&gt;</content><link href="http://blog.research.google/feeds/2254287928040727502/comments/default" rel="replies" title="Post Comments" type="application/atom+xml"/><link href="http://blog.research.google/2024/02/intervening-on-early-readouts-for.html#comment-form" rel="replies" title="0 Comments" type="text/html"/><link href="http://www.blogger.com/feeds/8474926331452026626/posts/default/2254287928040727502" rel="edit" type="application/atom+xml"/><link href="http://www.blogger.com/feeds/8474926331452026626/posts/default/2254287928040727502" rel="self" type="application/atom+xml"/><link href="http://blog.research.google/2024/02/intervening-on-early-readouts-for.html" rel="alternate" title="Intervening on early readouts for mitigating spurious features and simplicity bias" type="text/html"/><author><name>Google AI</name><uri>http://www.blogger.com/profile/12098626514775266161</uri><email>noreply@blogger.com</email><gd:image height="16" rel="http://schemas.google.com/g/2005#thumbnail" src="https://img1.blogblog.com/img/b16-rounded.gif" width="16"/></author><media:thumbnail xmlns:media="http://search.yahoo.com/mrss/" height="72" url="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgdBd5rMRA2U1nd8fetuEweTgmHncn49ASMQtPlm6dfsr5V29RwsoUR8UtK4B7oSE1eiIdW-vD-gjCUK4tGZTbsY4XdO0adL2YtAjpgbF1S3mL_Jw3f31SwLKYUtCOLJ807gdXdRmD5iVsrtc_Ii-BiqQacv89vbtRbNAIINa9PhKAF_sDAZu09FLs4599T/s72-c/SiFer%20Hero.png" width="72"/><thr:total>0</thr:total></entry><entry><id>tag:blogger.com,1999:blog-8474926331452026626.post-5966553114967673984</id><published>2024-01-31T13:59:00.000-08:00</published><updated>2024-01-31T13:59:36.056-08:00</updated><category scheme="http://www.blogger.com/atom/ns#" term="Computer Vision"/><category scheme="http://www.blogger.com/atom/ns#" term="Machine Learning"/><category scheme="http://www.blogger.com/atom/ns#" term="On-device Learning"/><title type="text">MobileDiffusion: Rapid text-to-image generation on-device</title><content 
type="html">&lt;span class="byline-author"&gt;Posted by Yang Zhao, Senior Software Engineer, and Tingbo Hou, Senior Staff Software Engineer, Core ML&lt;/span&gt;

&lt;img src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgOndf55Pc7tkXJektbVBEYRsOlxbUVui2uwOdXvuHj9cNpoNw2One4-68fqFNl2_fvv11CcgYfoI1XVQIkpjA9DosaOeqdkIRj9aZZJNoDy8KqB_XCVDtDd_EvT5UGL2ZhXvL2PU3RjN8XBjI0eQe8VIJCKI0-20AG0TKGK58mO9tBZa80P58KSjTU_liK/s1600/InstantTIGO%20hero.png" style="display: none;" /&gt;

&lt;p&gt;
Text-to-image &lt;a href="https://arxiv.org/abs/2006.11239"&gt;diffusion models&lt;/a&gt; have shown exceptional capabilities in generating high-quality images from text prompts. However, leading models feature billions of parameters and are consequently expensive to run, requiring powerful desktops or servers (e.g., &lt;a href="https://stability.ai/news/stable-diffusion-public-release"&gt;Stable Diffusion&lt;/a&gt;, &lt;a href="https://openai.com/research/dall-e"&gt;DALL·E&lt;/a&gt;, and &lt;a href="https://imagen.research.google/"&gt;Imagen&lt;/a&gt;). While recent advancements in inference solutions on &lt;a href="https://blog.research.google/2023/06/speed-is-all-you-need-on-device.html"&gt;Android&lt;/a&gt; via MediaPipe and &lt;a href="https://github.com/apple/ml-stable-diffusion"&gt;iOS&lt;/a&gt; via Core ML have been made in the past year, rapid (sub-second) text-to-image generation on mobile devices has remained out of reach.
&lt;/p&gt; &lt;a name='more'&gt;&lt;/a&gt;
&lt;p&gt;
To that end, in “&lt;a href="https://arxiv.org/abs/2311.16567"&gt;MobileDiffusion: Subsecond Text-to-Image Generation on Mobile Devices&lt;/a&gt;”, we introduce a novel approach with the potential for rapid text-to-image generation on-device. MobileDiffusion is an efficient latent diffusion model specifically designed for mobile devices. To achieve one-step sampling during inference, we also adopt &lt;a href="https://arxiv.org/abs/2311.09257"&gt;DiffusionGAN&lt;/a&gt;, which fine-tunes a pre-trained diffusion model while leveraging a GAN to model the denoising step. We have tested MobileDiffusion on premium iOS and Android devices, where it can generate a high-quality 512x512 image in half a second. Its comparatively small model size of just 520M parameters makes it uniquely suited for mobile deployment.
&lt;/p&gt;


&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;
  &lt;tr&gt;
    &lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgc9IegNp6IHze1sPewUyoR_WouBi8jMhiThcaavD0SXFld3788eA89uyOP6gpmdCXSZMMuacrgQMJ61ygVJsLfE51tqTmmYS0C-GI9SaF_hEGlhTp_zTFXdW_AgXIP5CLCejKQVCsPrhycF8p_Rj9qQHR0J_kTO8Md7VT5R47IMJHinO6dkHn23lUlU7rf/s800/image2.gif" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="800" data-original-width="369" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgc9IegNp6IHze1sPewUyoR_WouBi8jMhiThcaavD0SXFld3788eA89uyOP6gpmdCXSZMMuacrgQMJ61ygVJsLfE51tqTmmYS0C-GI9SaF_hEGlhTp_zTFXdW_AgXIP5CLCejKQVCsPrhycF8p_Rj9qQHR0J_kTO8Md7VT5R47IMJHinO6dkHn23lUlU7rf/s16000/image2.gif" /&gt;&lt;/a&gt;&lt;/td&gt;
    
    &lt;td&gt;&amp;nbsp;&amp;nbsp;&amp;nbsp;&amp;nbsp;&amp;nbsp;&amp;nbsp;&lt;/td&gt;
  
  
  &lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhpz0XGSpMH9OVTd865uusar0AeXtu_26HD3tHzJHm2iEVeLYynBhi6pl0tidIYOoJVamc-NplnsNPCNl3vMX-qjqEZCYtndsl-9YjulMpLiDbP3Uws9cZ5ITjb0C3MNaVNC5mh-kbyKZYXn5rxBAuPLaHg_56ZAJfPOrkBfh44goI3CnEW-XZFDUvJgWAV/s800/image5.gif" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="800" data-original-width="369" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhpz0XGSpMH9OVTd865uusar0AeXtu_26HD3tHzJHm2iEVeLYynBhi6pl0tidIYOoJVamc-NplnsNPCNl3vMX-qjqEZCYtndsl-9YjulMpLiDbP3Uws9cZ5ITjb0C3MNaVNC5mh-kbyKZYXn5rxBAuPLaHg_56ZAJfPOrkBfh44goI3CnEW-XZFDUvJgWAV/s16000/image5.gif" /&gt;&lt;/a&gt;&lt;/td&gt;
  
  
  
  &lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;

&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;Rapid text-to-image generation on-device.&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;




&lt;div style="line-height: 40%;"&gt;
    &lt;br /&gt;
&lt;/div&gt;
&lt;h2&gt;Background&lt;/h2&gt;


&lt;p&gt;
The relative inefficiency of text-to-image diffusion models arises from two primary challenges. First, the inherent design of diffusion models requires &lt;a href="https://blog.research.google/2023/06/on-device-diffusion-plugins-for.html"&gt;iterative denoising&lt;/a&gt; to generate images, necessitating multiple evaluations of the model. Second, the complexity of the network architecture in text-to-image diffusion models involves a substantial number of parameters, regularly reaching into the billions and resulting in computationally expensive evaluations. As a result, despite the potential benefits of deploying generative models on mobile devices, such as enhancing user experience and addressing emerging privacy concerns, on-device deployment remains relatively unexplored within the current literature.
&lt;/p&gt;
&lt;p&gt;
The optimization of inference efficiency in text-to-image diffusion models has been an active research area. Previous studies predominantly concentrate on addressing the first challenge, seeking to reduce the number of function evaluations (NFEs). By leveraging advanced numerical solvers (e.g., &lt;a href="https://arxiv.org/abs/2206.00927"&gt;DPM&lt;/a&gt;) or distillation techniques (e.g., &lt;a href="https://arxiv.org/abs/2202.00512"&gt;progressive distillation&lt;/a&gt;, &lt;a href="https://arxiv.org/abs/2303.01469"&gt;consistency distillation&lt;/a&gt;), these studies have significantly reduced the number of necessary sampling steps from several hundred to single digits. Some recent techniques, like &lt;a href="https://arxiv.org/abs/2311.09257"&gt;DiffusionGAN&lt;/a&gt; and &lt;a href="https://arxiv.org/abs/2311.17042#:~:text=We%20introduce%20Adversarial%20Diffusion%20Distillation,while%20maintaining%20high%20image%20quality."&gt;Adversarial Diffusion Distillation&lt;/a&gt;, even reduce it to a single step. 
&lt;/p&gt;
&lt;p&gt;
However, on mobile devices, even a small number of evaluation steps can be slow due to the complexity of the model architecture. Thus far, the architectural efficiency of text-to-image diffusion models has received comparatively less attention. A handful of earlier works briefly touch upon this matter by removing redundant neural network blocks (e.g., &lt;a href="https://snap-research.github.io/SnapFusion/"&gt;SnapFusion&lt;/a&gt;). However, these efforts lack a comprehensive analysis of each component within the model architecture, thereby falling short of providing a holistic guide for designing highly efficient architectures.
&lt;/p&gt;




&lt;div style="line-height: 40%;"&gt;
    &lt;br /&gt;
&lt;/div&gt;
&lt;h2&gt;MobileDiffusion&lt;/h2&gt;


&lt;p&gt;
Effectively overcoming the challenges imposed by the limited computational power of mobile devices requires an in-depth and holistic exploration of the model's architectural efficiency. In pursuit of this objective, our research undertakes a detailed examination of each constituent and computational operation within Stable Diffusion’s &lt;a href="https://arxiv.org/abs/2112.10752"&gt;UNet architecture&lt;/a&gt;. We present a comprehensive guide for crafting highly efficient text-to-image diffusion models, culminating in MobileDiffusion.
&lt;/p&gt;
&lt;p&gt;
The design of MobileDiffusion follows that of &lt;a href="https://arxiv.org/abs/2112.10752"&gt;latent diffusion models&lt;/a&gt;. It contains three components: a text encoder, a diffusion UNet, and an image decoder. For the text encoder, we use &lt;a href="https://arxiv.org/abs/2103.00020"&gt;CLIP-ViT/L14&lt;/a&gt;, which is a small model (125M parameters) suitable for mobile. We then turn our focus to the diffusion UNet and image decoder. 
&lt;/p&gt;



&lt;div style="line-height: 40%;"&gt;
    &lt;br /&gt;
&lt;/div&gt;
&lt;h3&gt;Diffusion UNet&lt;/h3&gt;


&lt;p&gt;
As illustrated in the figure below, diffusion UNets commonly interleave transformer blocks and convolution blocks. We conduct a comprehensive investigation of these two fundamental building blocks. Throughout the study, we control the training pipeline (e.g., data, optimizer) to study the effects of different architectures.
&lt;/p&gt;
&lt;p&gt;
In classic text-to-image diffusion models, a transformer block consists of a self-attention layer (SA) for modeling long-range dependencies among visual features, a cross-attention layer (CA) to capture interactions between text conditioning and visual features, and a feed-forward layer (FF) to post-process the output of the attention layers. These transformer blocks hold a pivotal role in text-to-image diffusion models, serving as the primary components responsible for text comprehension. However, they also pose a significant efficiency challenge, given the computational expense of the attention operation, which is quadratic in the sequence length. We follow the idea of the &lt;a href="https://arxiv.org/abs/2301.11093"&gt;UViT&lt;/a&gt; architecture, which places more transformer blocks at the bottleneck of the UNet. This design choice is motivated by the fact that the attention computation is less resource-intensive at the bottleneck due to its lower dimensionality. 
&lt;/p&gt;





&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgsshK53k6noqIbabpGMBzYIBCdviXisDoBsD3Houk-lXzN8pZQcusKYBvjWwcwA1Aq5DnWyk01YM9B2RyRZx6HcGgTP-LrW-tnwFwByzlBACN3WggyPYM0Mpyr2OVGVLFhx1uN48aR1g9P4o0joN2STli9VpA_tFMdQ-ikRXVrNpawzB793-unSENR-PIV/s915/image4.png" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="249" data-original-width="915" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgsshK53k6noqIbabpGMBzYIBCdviXisDoBsD3Houk-lXzN8pZQcusKYBvjWwcwA1Aq5DnWyk01YM9B2RyRZx6HcGgTP-LrW-tnwFwByzlBACN3WggyPYM0Mpyr2OVGVLFhx1uN48aR1g9P4o0joN2STli9VpA_tFMdQ-ikRXVrNpawzB793-unSENR-PIV/s16000/image4.png" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;Our UNet architecture incorporates more transformers in the middle, and skips self-attention (SA) layers at higher resolutions.&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;




&lt;p&gt;
Convolution blocks, in particular &lt;a href="https://arxiv.org/abs/1512.03385"&gt;ResNet&lt;/a&gt; blocks, are deployed at each level of the UNet. While these blocks are instrumental for feature extraction and information flow, the associated computational costs, especially at high-resolution levels, can be substantial. One proven approach in this context is &lt;a href="https://arxiv.org/abs/1704.04861"&gt;separable convolution&lt;/a&gt;. We observed that replacing regular convolution layers with lightweight separable convolution layers in the deeper segments of the UNet yields similar performance.
&lt;/p&gt;
&lt;p&gt;
In the figure below, we compare the UNets of several diffusion models. Our MobileDiffusion exhibits superior efficiency in terms of &lt;a href="https://arxiv.org/pdf/2110.12894.pdf"&gt;FLOPs&lt;/a&gt; (floating-point operations) and number of parameters. 
&lt;/p&gt;




&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjXYleITSssbZnLffeh3BzG3tX2qNQNeB__xc-ySks0SPnXsMb2kTLZ0PcE2KWJ4I9FX_QMP32pXd06IuV1kJJSlgp7CuV6dqkXJsiFqo_6xqWXZ1-65p_EPU9gk7G9B4-L2TaKGiD5cahwg428CTmV1dcuQQ_vBTVmP8543IJigIF0qHo8_JaB8h5EuVvl/s1200/image3.png" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="742" data-original-width="1200" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjXYleITSssbZnLffeh3BzG3tX2qNQNeB__xc-ySks0SPnXsMb2kTLZ0PcE2KWJ4I9FX_QMP32pXd06IuV1kJJSlgp7CuV6dqkXJsiFqo_6xqWXZ1-65p_EPU9gk7G9B4-L2TaKGiD5cahwg428CTmV1dcuQQ_vBTVmP8543IJigIF0qHo8_JaB8h5EuVvl/s16000/image3.png" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;Comparison of some diffusion UNets.&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;




&lt;div style="line-height: 40%;"&gt;
    &lt;br /&gt;
&lt;/div&gt;
&lt;h3&gt;Image decoder&lt;/h3&gt;


&lt;p&gt;
In addition to the UNet, we also optimized the image decoder. We trained a &lt;a href="https://arxiv.org/abs/2012.03715"&gt;variational autoencoder&lt;/a&gt; (VAE) to encode an &lt;a href="https://en.wikipedia.org/wiki/RGB_color_model"&gt;RGB&lt;/a&gt; image to an 8-channel latent variable whose spatial size is 8× smaller than that of the image. Decoding a latent variable produces an image that is correspondingly 8× larger in spatial size. To further enhance efficiency, we designed a lightweight decoder architecture by pruning the original decoder’s width and depth. The resulting lightweight decoder leads to a significant performance boost, with nearly 50% latency improvement and better quality. For more details, please refer to our &lt;a href="https://arxiv.org/abs/2311.16567"&gt;paper&lt;/a&gt;.
&lt;/p&gt;




&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjT2Nmo7GjGdN0_2dqevJB52RogqnWFDVmFsrusHHxnVf9YQYsdbVkAQvBI3h9SzKZ0TqOQOmnxaZ6z2kdix12tei5oMpD17SY1LoBWqxD1EHgV0ygTb9TV0IFZQtv4dAix378lb8WGv5GGPQIuyStX3gWqn0pjTTXbpIlA0VzYSeiGpkO5bsHhZfjbkR07/s1124/image6.png" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="789" data-original-width="1124" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjT2Nmo7GjGdN0_2dqevJB52RogqnWFDVmFsrusHHxnVf9YQYsdbVkAQvBI3h9SzKZ0TqOQOmnxaZ6z2kdix12tei5oMpD17SY1LoBWqxD1EHgV0ygTb9TV0IFZQtv4dAix378lb8WGv5GGPQIuyStX3gWqn0pjTTXbpIlA0VzYSeiGpkO5bsHhZfjbkR07/s16000/image6.png" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;VAE reconstruction. Our VAE decoders have better visual quality than SD (Stable Diffusion).&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;



&lt;br /&gt;

&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="text-align: center;"&gt;
  &lt;tbody&gt;&lt;tr&gt;
   &lt;td style="text-align: left;"&gt;&lt;b&gt;Decoder&lt;/b&gt;
   &lt;/td&gt;
   &lt;td&gt;&lt;b&gt;&amp;nbsp;&amp;nbsp;#Params (M)&amp;nbsp;&amp;nbsp;&lt;/b&gt;
   &lt;/td&gt;
   &lt;td&gt;&lt;b&gt;&amp;nbsp;&amp;nbsp;PSNR↑&amp;nbsp;&amp;nbsp;&lt;/b&gt;
   &lt;/td&gt;
   &lt;td&gt;&lt;b&gt;&amp;nbsp;&amp;nbsp;SSIM↑&amp;nbsp;&amp;nbsp;&lt;/b&gt;
   &lt;/td&gt;
   &lt;td&gt;&lt;b&gt;&amp;nbsp;&amp;nbsp;LPIPS↓&amp;nbsp;&amp;nbsp;&lt;/b&gt;
   &lt;/td&gt;
  &lt;/tr&gt;
  &lt;tr&gt;
   &lt;td style="text-align: left;"&gt;&lt;b&gt;SD&lt;/b&gt;
   &lt;/td&gt;
   &lt;td&gt;49.5
   &lt;/td&gt;
   &lt;td&gt;26.7
   &lt;/td&gt;
   &lt;td&gt;0.76
   &lt;/td&gt;
   &lt;td&gt;0.037
   &lt;/td&gt;
  &lt;/tr&gt;
  &lt;tr&gt;
   &lt;td style="text-align: left;"&gt;&lt;b&gt;Ours&lt;/b&gt;
   &lt;/td&gt;
   &lt;td&gt;39.3
   &lt;/td&gt;
   &lt;td&gt;30.0
   &lt;/td&gt;
   &lt;td&gt;0.83
   &lt;/td&gt;
   &lt;td&gt;0.032
   &lt;/td&gt;
  &lt;/tr&gt;
  &lt;tr&gt;
   &lt;td style="text-align: left;"&gt;&lt;b&gt;Ours-Lite&amp;nbsp;&amp;nbsp;&amp;nbsp;&amp;nbsp;&lt;/b&gt;
   &lt;/td&gt;
   &lt;td&gt;9.8
   &lt;/td&gt;
   &lt;td&gt;30.2
   &lt;/td&gt;
   &lt;td&gt;0.84
   &lt;/td&gt;
   &lt;td&gt;0.032
   &lt;/td&gt;
  &lt;/tr&gt;
&lt;/tbody&gt;&lt;/table&gt;
&lt;br /&gt;

&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;Quality evaluation of VAE decoders. Our lite decoder is much smaller than SD, with better quality metrics, including &lt;a href="https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio"&gt;peak signal-to-noise ratio&lt;/a&gt; (PSNR), &lt;a href="https://en.wikipedia.org/wiki/Structural_similarity"&gt;structural similarity index measure&lt;/a&gt; (SSIM), and &lt;a href="https://arxiv.org/abs/1801.03924"&gt;Learned Perceptual Image Patch Similarity&lt;/a&gt; (LPIPS).&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;


&lt;div style="line-height: 40%;"&gt;
    &lt;br /&gt;
&lt;/div&gt;
&lt;h3&gt;One-step sampling&lt;/h3&gt;


&lt;p&gt;
In addition to optimizing the model architecture, we adopt a &lt;a href="https://arxiv.org/abs/2311.09257"&gt;DiffusionGAN hybrid&lt;/a&gt; to achieve one-step sampling. Training DiffusionGAN hybrid models for text-to-image generation involves several intricacies. Notably, the discriminator, a classifier distinguishing between real and generated data, must make judgments based on both texture and semantics. Moreover, the cost of training text-to-image models can be extremely high, particularly in the case of GAN-based models, where the discriminator introduces additional parameters. Purely GAN-based text-to-image models (e.g., &lt;a href="https://arxiv.org/abs/2301.09515"&gt;StyleGAN-T&lt;/a&gt;, &lt;a href="https://arxiv.org/abs/2303.05511"&gt;GigaGAN&lt;/a&gt;) confront similar complexities, resulting in highly intricate and expensive training.
&lt;/p&gt;
&lt;p&gt;
To overcome these challenges, we use a pre-trained diffusion UNet to initialize the generator and discriminator. This design enables seamless initialization with the pre-trained diffusion model. We postulate that the internal features within the diffusion model contain rich information about the intricate interplay between textual and visual data. This initialization strategy significantly streamlines the training.
&lt;/p&gt;
&lt;p&gt;
The figure below illustrates the training procedure. After initialization, a noisy image is sent to the generator for one-step diffusion. The result is evaluated against ground truth with a reconstruction loss, similar to diffusion model training. We then add noise to the output and send it to the discriminator, whose result is evaluated with a GAN loss, effectively using the GAN to model a denoising step. Because the generator and the discriminator are initialized with pre-trained weights, training becomes a fine-tuning process that converges in fewer than 10K iterations.  
&lt;/p&gt;



&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjnK7SE2-cHSlP-PDmkl_xfjp3sP-kB41r6OvC8Wg6miXnYwdES0INwN19BHWQ_uyXtcBT-872U5J6jLY8yXVtA_W96qkRRPh6Pjvw0n-ZJvjJK91kYTh7H1n4nzy8z1TyrQZlZoZrQUDTo5Qm-6a_2vIVye3aqm7o32qOOXiWXwxDzw_J6cQsOrJ-UILKw/s960/image7.jpg" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="576" data-original-width="960" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjnK7SE2-cHSlP-PDmkl_xfjp3sP-kB41r6OvC8Wg6miXnYwdES0INwN19BHWQ_uyXtcBT-872U5J6jLY8yXVtA_W96qkRRPh6Pjvw0n-ZJvjJK91kYTh7H1n4nzy8z1TyrQZlZoZrQUDTo5Qm-6a_2vIVye3aqm7o32qOOXiWXwxDzw_J6cQsOrJ-UILKw/s16000/image7.jpg" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;Illustration of DiffusionGAN fine-tuning.&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;



&lt;div style="line-height: 40%;"&gt;
    &lt;br /&gt;
&lt;/div&gt;
&lt;h2&gt;Results&lt;/h2&gt;


&lt;p&gt;
Below we show example images generated by our MobileDiffusion with DiffusionGAN one-step sampling. With such a compact model (520M parameters in total), MobileDiffusion can generate high-quality diverse images for various domains.
&lt;/p&gt;




&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEiyDLq1NW7Qvy4_oEqg1pHAMzeBfuei3VadIKZRNkv6ZHnzewVWQU5x76e0bm-QqWVr-_q1W4axBJeyqyCbdRFoUFBYxRxDj3qo7I4-Du6TS2Bez_-mmXzYoHLJk7y5fiKl9PPkHNk_dsvy7ezuAFavW4sYIeYTxhAPAH35FYP5YOceS8NfJey0gpvHUwza/s1728/image1.png" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="1296" data-original-width="1728" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEiyDLq1NW7Qvy4_oEqg1pHAMzeBfuei3VadIKZRNkv6ZHnzewVWQU5x76e0bm-QqWVr-_q1W4axBJeyqyCbdRFoUFBYxRxDj3qo7I4-Du6TS2Bez_-mmXzYoHLJk7y5fiKl9PPkHNk_dsvy7ezuAFavW4sYIeYTxhAPAH35FYP5YOceS8NfJey0gpvHUwza/s16000/image1.png" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;Images generated by our MobileDiffusion&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;



&lt;p&gt;
We measured the performance of our MobileDiffusion on both iOS and Android devices, using different runtime optimizers. The latency numbers are reported below. We see that MobileDiffusion is very efficient and can run within half a second to generate a 512x512 image. This lightning speed potentially enables many interesting use cases on mobile devices.
&lt;/p&gt;



&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEiFkcI7kibwRFhpxTsVmUkAzK38MCeBoTR6fOWyhjnqwPm7x8TwrVn_O0OipsXCbgS4qTtcbtm41Fxi7U_IJjpeuZadWO7cBKkcdrXHniAJgQP4Qk-wOBfnhtwNPxDbzxtM0uxVba3BjwzLa3Lw13-03FoRQbWwf_25KR9GLLkSqIFpnU5aE-6hnomY5IuK/s1184/image8.png" imageanchor="1" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="742" data-original-width="1184" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEiFkcI7kibwRFhpxTsVmUkAzK38MCeBoTR6fOWyhjnqwPm7x8TwrVn_O0OipsXCbgS4qTtcbtm41Fxi7U_IJjpeuZadWO7cBKkcdrXHniAJgQP4Qk-wOBfnhtwNPxDbzxtM0uxVba3BjwzLa3Lw13-03FoRQbWwf_25KR9GLLkSqIFpnU5aE-6hnomY5IuK/s16000/image8.png" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;Latency measurements (&lt;b&gt;s&lt;/b&gt;) on mobile devices.&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;



&lt;div style="line-height: 40%;"&gt;
    &lt;br /&gt;
&lt;/div&gt;
&lt;h2&gt;Conclusion&lt;/h2&gt;


&lt;p&gt;
With superior efficiency in terms of latency and size, MobileDiffusion has the potential to be a very friendly option for mobile deployments, given its capability to enable a rapid image generation experience while typing text prompts. We will ensure any application of this technology will be in line with Google’s &lt;a href="https://ai.google/responsibility/responsible-ai-practices/"&gt;responsible AI practices&lt;/a&gt;.
&lt;/p&gt;



&lt;div style="line-height: 40%;"&gt;
    &lt;br /&gt;
&lt;/div&gt;
&lt;h2&gt;Acknowledgments&lt;/h2&gt;


&lt;p&gt;
&lt;em&gt;We would like to thank our collaborators and contributors who helped bring MobileDiffusion on-device: Zhisheng Xiao, Yanwu Xu, Jiuqiang Tang, Haolin Jia, Lutz Justen, Daniel Fenner, Ronald Wotzlaw, Jianing Wei, Raman Sarokin, Juhyun Lee, Andrei Kulik, Chuo-Ling Chang, and Matthias Grundmann.&lt;/em&gt;
&lt;/p&gt;


</content><link href="http://blog.research.google/feeds/5966553114967673984/comments/default" rel="replies" title="Post Comments" type="application/atom+xml"/><link href="http://blog.research.google/2024/01/mobilediffusion-rapid-text-to-image.html#comment-form" rel="replies" title="0 Comments" type="text/html"/><link href="http://www.blogger.com/feeds/8474926331452026626/posts/default/5966553114967673984" rel="edit" type="application/atom+xml"/><link href="http://www.blogger.com/feeds/8474926331452026626/posts/default/5966553114967673984" rel="self" type="application/atom+xml"/><link href="http://blog.research.google/2024/01/mobilediffusion-rapid-text-to-image.html" rel="alternate" title="MobileDiffusion: Rapid text-to-image generation on-device" type="text/html"/><author><name>Google AI</name><uri>http://www.blogger.com/profile/12098626514775266161</uri><email>noreply@blogger.com</email><gd:image height="16" rel="http://schemas.google.com/g/2005#thumbnail" src="https://img1.blogblog.com/img/b16-rounded.gif" width="16"/></author><media:thumbnail xmlns:media="http://search.yahoo.com/mrss/" height="72" url="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgOndf55Pc7tkXJektbVBEYRsOlxbUVui2uwOdXvuHj9cNpoNw2One4-68fqFNl2_fvv11CcgYfoI1XVQIkpjA9DosaOeqdkIRj9aZZJNoDy8KqB_XCVDtDd_EvT5UGL2ZhXvL2PU3RjN8XBjI0eQe8VIJCKI0-20AG0TKGK58mO9tBZa80P58KSjTU_liK/s72-c/InstantTIGO%20hero.png" width="72"/><thr:total>0</thr:total></entry><entry><id>tag:blogger.com,1999:blog-8474926331452026626.post-5144906729109253495</id><published>2024-01-26T11:56:00.000-08:00</published><updated>2024-01-26T11:56:23.553-08:00</updated><category scheme="http://www.blogger.com/atom/ns#" term="Algorithms"/><category scheme="http://www.blogger.com/atom/ns#" term="Deep Learning"/><category scheme="http://www.blogger.com/atom/ns#" term="optimization"/><title type="text">Mixed-input matrix multiplication performance optimizations</title><content type="html">&lt;span class="byline-author"&gt;Posted by 
Manish Gupta, Staff Software Engineer, Google Research&lt;/span&gt;

&lt;img src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhEKJJf1R773hab0veY6zffF2Nf_yfV2mk8YU9yRnuBDD3ak1o0iXecWlJw2x7bL-Ez2MX1c21MXk65VMK5IsoLpJ1H6BTC6k7BvVWl_gHJpJIOG2cm3BwP4V-HCScGHYIynuskbhvu1uorQGprHGbOFmfGI7E5UWemJcZ0xSC3tC5DolBYgyBwugl6OOLr/s1180/matrixhero.png" style="display: none;" /&gt;

&lt;p&gt;
AI-driven technologies are weaving themselves into the fabric of our daily routines, with the potential to enhance our access to knowledge and boost our overall productivity. The backbone of these applications lies in large language models (LLMs). LLMs are memory-intensive and typically require specialized hardware accelerators to efficiently deliver &lt;a href="https://cloud.google.com/blog/products/compute/the-worlds-largest-distributed-llm-training-job-on-tpu-v5e"&gt;tens of exaflops&lt;/a&gt; of computing power. This blog post shows how we can start addressing the computational challenges by utilizing memory more effectively.
&lt;/p&gt;
&lt;a name='more'&gt;&lt;/a&gt;
&lt;p&gt;
The bulk of an LLM’s memory and compute are consumed by &lt;a href="https://arxiv.org/pdf/2005.14165.pdf"&gt;weights&lt;/a&gt; in &lt;a href="https://arxiv.org/pdf/2006.16668.pdf"&gt;matrix multiplication&lt;/a&gt; operations. Using narrower &lt;em&gt;&lt;a href="https://en.wikipedia.org/wiki/Primitive_data_type"&gt;data types&lt;/a&gt;&lt;/em&gt; reduces memory consumption. For example, storing weights in the 8-bit &lt;a href="https://en.wikipedia.org/wiki/Integer_(computer_science)"&gt;integer&lt;/a&gt; (i.e., U8 or S8) data type reduces the memory footprint by 4× relative to &lt;a href="https://en.wikipedia.org/wiki/Single-precision_floating-point_format"&gt;single-precision&lt;/a&gt; (F32) and 2× relative to &lt;a href="https://en.wikipedia.org/wiki/Half-precision_floating-point_format"&gt;half-precision&lt;/a&gt; (F16) or &lt;a href="https://en.wikipedia.org/wiki/Bfloat16_floating-point_format"&gt;bfloat16&lt;/a&gt; (BF16). Furthermore, &lt;a href="https://arxiv.org/pdf/2206.01861.pdf"&gt;previous work has&lt;/a&gt; shown that running LLM matrix multiplications with &lt;em&gt;weights&lt;/em&gt; in S8 and &lt;em&gt;input&lt;/em&gt; in F16 (preserving the higher precision of the user input) is an effective method for increasing efficiency with acceptable trade-offs in accuracy. This technique is known as &lt;em&gt;weight-only quantization&lt;/em&gt; and requires efficient implementation of matrix multiplication with &lt;em&gt;mixed-inputs&lt;/em&gt;, e.g., half-precision input multiplied with 8-bit integer. Hardware accelerators, including GPUs, support a fixed set of data types, and thus, mixed-input matrix multiplication requires software transformations to map to the hardware operations.
&lt;/p&gt;
&lt;p&gt;
To that end, in this blog we focus on mapping mixed-input matrix multiplication onto the &lt;a href="https://developer.nvidia.com/blog/nvidia-ampere-architecture-in-depth/"&gt;NVIDIA Ampere architecture&lt;/a&gt;. We present software techniques addressing data type conversion and layout conformance to map mixed-input matrix multiplication efficiently onto hardware-supported data types and layouts. Our results show that the overhead of additional work in software is minimal and enables performance close to the peak hardware capabilities. The software techniques described here are released in the open-source &lt;a href="https://github.com/NVIDIA/cutlass/pull/1084"&gt;NVIDIA/CUTLASS&lt;/a&gt; repository. 
&lt;/p&gt;

&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgaLaSxuLbV_5ifXLyJsTGs0WLa23prrxrhX4IKSLZw5l3oSd2SPk5AgZtNgvUY_j-IbOyjttva-XIfkRr1cDBwCXghEz-3Q0G-6236m7_TIgTrm_K2UejYnTnhAEmZtKHq1mN9HKP0xxV8nqSxzTNHG1U0j-cVj236efpR7lSgmt082QEYNwKsGMTRiWZb/s1999/image3.png" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="1159" data-original-width="1999" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgaLaSxuLbV_5ifXLyJsTGs0WLa23prrxrhX4IKSLZw5l3oSd2SPk5AgZtNgvUY_j-IbOyjttva-XIfkRr1cDBwCXghEz-3Q0G-6236m7_TIgTrm_K2UejYnTnhAEmZtKHq1mN9HKP0xxV8nqSxzTNHG1U0j-cVj236efpR7lSgmt082QEYNwKsGMTRiWZb/s16000/image3.png" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;Memory footprint for a 175B-parameter LLM with various data type formats.&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;



&lt;div style="line-height: 40%;"&gt;
    &lt;br /&gt;
&lt;/div&gt;
&lt;h2&gt;The matrix-multiply-accumulate operation&lt;/h2&gt;


&lt;p&gt;
Modern AI hardware accelerators such as &lt;a href="https://cloud.google.com/tpu/docs/intro-to-tpu#how_a_tpu_works"&gt;Google’s TPU&lt;/a&gt; and &lt;a href="https://www.nvidia.com/en-us/data-center/tensor-cores/"&gt;NVIDIA’s GPU&lt;/a&gt; multiply matrices natively in the hardware by targeting Tensor Cores, which are specialized processing elements to accelerate matrix operations, particularly for AI workloads. In this blog, we focus on NVIDIA Ampere Tensor Cores, which provide the &lt;em&gt;matrix-multiply-accumulate&lt;/em&gt; (&lt;code&gt;&lt;a href="https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-mma"&gt;mma&lt;/a&gt;&lt;/code&gt;) operation. For the rest of the blog the reference to &lt;span style="color: #54863f;"&gt;&lt;code&gt;mma&lt;/code&gt;&lt;/span&gt; is for Ampere Tensor Cores. The supported data types, shapes, and data layout of the two input matrices (called operands) for the &lt;span style="color: #54863f;"&gt;&lt;code&gt;mma&lt;/code&gt;&lt;/span&gt; operation are fixed in hardware. This means that matrix multiplications with various data types and larger shapes are implemented in the software by tiling the problem onto hardware-supported data types, shapes, and layouts. 

&lt;/p&gt;
&lt;p&gt;
The Tensor Core &lt;span style="color: #54863f;"&gt;&lt;code&gt;mma&lt;/code&gt;&lt;/span&gt; operation is defined by specifying two input matrices (e.g., &lt;em&gt;A&lt;/em&gt; &amp;amp; &lt;em&gt;B&lt;/em&gt;, shown below) to produce a result matrix, &lt;em&gt;C&lt;/em&gt;. The &lt;span style="color: #54863f;"&gt;&lt;code&gt;mma&lt;/code&gt;&lt;/span&gt; operation natively supports mixed-precision. &lt;em&gt;&lt;a href="https://developer.nvidia.com/blog/programming-tensor-cores-cuda-9/"&gt;Mixed-precision Tensor Cores&lt;/a&gt;&lt;/em&gt; allow mixing input (&lt;em&gt;A&lt;/em&gt; and &lt;em&gt;B&lt;/em&gt;) data type with the result (&lt;em&gt;C&lt;/em&gt;) data type. In contrast, &lt;em&gt;mixed-input &lt;/em&gt;matrix multiplication involves mixing the input data types, and it is not supported by the hardware, so it needs to be implemented in the software.
&lt;/p&gt;

&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEiS_vu1tTxHo9Gy6Mywfx1xbQ0G6XTpOOQ04-l-Nw_rM7qOAM9kXg_qDjIakIpx-IclRmfR96cTGGExo2k9fxnVdltW4I9nb7RHloRtqWFMFeOtZ68Yr5wve9uLTIsZKA3GxB_VaNo98Gfsa7zGGP0dCrjebZ0Fq1dutfoxoy25eByHXorHCwTTiqsFzw6M/s1039/image5.png" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="668" data-original-width="1039" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEiS_vu1tTxHo9Gy6Mywfx1xbQ0G6XTpOOQ04-l-Nw_rM7qOAM9kXg_qDjIakIpx-IclRmfR96cTGGExo2k9fxnVdltW4I9nb7RHloRtqWFMFeOtZ68Yr5wve9uLTIsZKA3GxB_VaNo98Gfsa7zGGP0dCrjebZ0Fq1dutfoxoy25eByHXorHCwTTiqsFzw6M/s16000/image5.png" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;Tensor Core operation of M-by-N-by-K on input matrix A of M-by-K and matrix B of K-by-N produces output matrix C of M-by-N.&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;






&lt;div style="line-height: 40%;"&gt;
    &lt;br /&gt;
&lt;/div&gt;
&lt;h2&gt;Challenges of mixed-input matrix multiplication&lt;/h2&gt;


&lt;p&gt;
To simplify the discussion, we restrict to a specific example of mixed-input matrix multiplication: F16 for user input and U8 for the model weights (written as F16 * U8). The techniques described here work for various combinations of mixed-input data types. 
&lt;/p&gt;
&lt;p&gt;
A GPU programmer can access a &lt;a href="https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#memory-hierarchy"&gt;hierarchy of memory&lt;/a&gt;, including global memory, shared memory, and registers, which are arranged in order of decreasing capacity but increasing speed. NVIDIA Ampere Tensor Core &lt;span style="color: #54863f;"&gt;&lt;code&gt;mma&lt;/code&gt;&lt;/span&gt; operations consume input matrices from registers. Furthermore, input and output matrices are required to conform to a layout of data within a group of 32 threads known as a &lt;em&gt;warp&lt;/em&gt;. The supported data type &lt;em&gt;and&lt;/em&gt; layout within a warp are fixed for an &lt;span style="color: #54863f;"&gt;&lt;code&gt;mma&lt;/code&gt;&lt;/span&gt; operation, so to implement mixed-input multiplication efficiently, it is necessary to solve the challenges of data type conversion and layout conformance in software. 
&lt;/p&gt;

&lt;div style="line-height: 40%;"&gt;
    &lt;br /&gt;
&lt;/div&gt;
&lt;h3&gt;Data type conversion &lt;/h3&gt;


&lt;p&gt;
The &lt;span style="color: #54863f;"&gt;&lt;code&gt;mma&lt;/code&gt;&lt;/span&gt; operation requires two input matrices with the same data type. Thus, mixed-input matrix multiplication, where one of the operands is stored in U8 in global memory and the other in F16, requires a data type conversion from U8 to F16. The conversion brings both operands to F16, mapping the &lt;em&gt;mixed-input&lt;/em&gt; matrix multiplication to hardware-supported &lt;em&gt;mixed-precision&lt;/em&gt; Tensor Cores. Given the large number of weights, there are a large number of such operations, and our techniques show how to reduce their latency and improve performance.
&lt;/p&gt;


&lt;div style="line-height: 40%;"&gt;
    &lt;br /&gt;
&lt;/div&gt;
&lt;h3&gt;Layout conformance &lt;/h3&gt;


&lt;p&gt;
The &lt;span style="color: #54863f;"&gt;&lt;code&gt;mma&lt;/code&gt;&lt;/span&gt; operation also requires the layout of the two input matrices, within the registers of a warp, to be conformant with the hardware specification. The layout for the input matrix &lt;em&gt;B&lt;/em&gt; of U8 data type in mixed-input matrix multiplication (F16 * U8) needs to conform with the converted F16 data type. This is called &lt;em&gt;layout conformance&lt;/em&gt; and needs to be achieved in software.
&lt;/p&gt;
&lt;p&gt;
The figure below shows an &lt;span style="color: #54863f;"&gt;&lt;code&gt;mma&lt;/code&gt;&lt;/span&gt; operation consuming matrix &lt;em&gt;A&lt;/em&gt; and matrix &lt;em&gt;B&lt;/em&gt; from registers to produce matrix &lt;em&gt;C&lt;/em&gt; in registers, distributed across one warp. The thread &lt;em&gt;T0&lt;/em&gt; is highlighted and zoomed in to show that the weight matrix &lt;em&gt;B&lt;/em&gt; goes through data type conversion and needs layout conformance to be able to map to the hardware-supported Tensor Core operation.
&lt;/p&gt;

&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEiMMvieW8Uyta8c4afsNM7SgyZtlB2ra7G7aBG4z7D73rn-T7NHge0J1zfK7A_edL9tsQIthWVtEd0hZmwAjfO5C-XM6d5hNkv8IEBlpRxHilOxFgjYi27qauWFAQTl5wV8ixQ9MrfvqpuEQrdFuqDtjPJESG795s6cH3FlPJIVS4TuvKo0gmd8L1HwOJ_6/s1999/image4.png" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="1240" data-original-width="1999" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEiMMvieW8Uyta8c4afsNM7SgyZtlB2ra7G7aBG4z7D73rn-T7NHge0J1zfK7A_edL9tsQIthWVtEd0hZmwAjfO5C-XM6d5hNkv8IEBlpRxHilOxFgjYi27qauWFAQTl5wV8ixQ9MrfvqpuEQrdFuqDtjPJESG795s6cH3FlPJIVS4TuvKo0gmd8L1HwOJ_6/s16000/image4.png" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;The mapping of mixed-input (F32 = F16 * U8) operation in software to natively supported warp-level Tensor Cores in hardware (F32 = F16 * F16). (Original figure source &lt;a href="https://www.nvidia.com/en-us/on-demand/session/gtcsj20-s21745/"&gt;Developing CUDA kernels to push Tensor Cores to the Absolute Limit on NVIDIA A100&lt;/a&gt;.)&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;







&lt;div style="line-height: 40%;"&gt;
    &lt;br /&gt;
&lt;/div&gt;
&lt;h2&gt;Software strategies addressing challenges&lt;/h2&gt;


&lt;p&gt;
A typical data type conversion involves a sequence of operations on 32-bit registers, shown below. Each rectangular block represents a register, and the adjoining text describes the operations applied to it. The entire sequence shows the conversion from 4xU8 to 2x(2xF16). The sequence involves roughly 10 operations.
&lt;/p&gt;

&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEiyJ4C214tiBhdjds0fWCV9EWh8X_UEDQlFqkpeoo6CZR3QMMrWyqi5mfRjvHLtbHH55J4hM5oRxe0HouGnbE3KuPbmh8MKk-TtDMMZv1YMKPv-Q4gYAr5l3ZXdTIPUHKs7f8wfCgr3XPe6_jUO7u12pGEmZVFiAGn_LCOlUlQQRSF7_r7jlOrPJW9Oc4V1/s947/image1.png" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="836" data-original-width="947" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEiyJ4C214tiBhdjds0fWCV9EWh8X_UEDQlFqkpeoo6CZR3QMMrWyqi5mfRjvHLtbHH55J4hM5oRxe0HouGnbE3KuPbmh8MKk-TtDMMZv1YMKPv-Q4gYAr5l3ZXdTIPUHKs7f8wfCgr3XPe6_jUO7u12pGEmZVFiAGn_LCOlUlQQRSF7_r7jlOrPJW9Oc4V1/s16000/image1.png" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;&lt;code&gt;&lt;a href="https://github.com/NVIDIA/cutlass/blob/757275f2796bb901575c633e2a32bc76ca84ffec/include/cutlass/numeric_conversion.h#L760"&gt;NumericArrayConvertor&lt;/a&gt;&lt;/code&gt; from 4xU8 to 2x(2xF16) in 32-bit registers.&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;






&lt;p&gt;
There are many ways of achieving layout conformance. Two of the existing solutions are:
&lt;/p&gt;
&lt;ol&gt;

&lt;li&gt;&lt;em&gt;Narrower bitwidth shared memory loads&lt;/em&gt;: In this approach, threads issue narrow bitwidth memory loads moving the U8 data from shared memory to registers. This results in &lt;em&gt;two&lt;/em&gt; 32-bit registers, with each register containing 2xF16 values (shown above for the matrix &lt;em&gt;B&lt;/em&gt;’s thread &lt;em&gt;T0&lt;/em&gt;). The narrower shared memory load achieves layout conformance directly into registers without needing any shuffles; however, it does not utilize the full shared memory bandwidth.

&lt;/li&gt;&lt;li&gt;&lt;em&gt;Pre-processing in global memory&lt;/em&gt;: An &lt;a href="https://arxiv.org/pdf/2211.10017.pdf"&gt;alternative strategy&lt;/a&gt; involves rearranging the data within the global memory (one level above the shared memory in &lt;a href="https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#memory-hierarchy"&gt;memory hierarchy&lt;/a&gt;), allowing wider shared memory loads. This approach maximizes the shared memory bandwidth utilization and ensures that the data is loaded in a conformant layout directly in the registers. Although the rearrangement process can be executed offline prior to the LLM deployment, ensuring no impact on the application performance, it introduces an additional, non-trivial hardware-specific pre-processing step that requires an extra program to rearrange the data. &lt;a href="https://github.com/NVIDIA/FasterTransformer"&gt;NVIDIA/FasterTransformer&lt;/a&gt; adopts this method to effectively address layout conformance challenges.
&lt;/li&gt;
&lt;/ol&gt;


&lt;div style="line-height: 40%;"&gt;
    &lt;br /&gt;
&lt;/div&gt;
&lt;h2&gt;Optimized software strategies&lt;/h2&gt;


&lt;p&gt;
To further optimize and reduce the overhead of data type conversion and layout conformance, we have implemented &lt;code&gt;&lt;a href="https://github.com/NVIDIA/cutlass/blob/757275f2796bb901575c633e2a32bc76ca84ffec/include/cutlass/numeric_conversion.h#L2514"&gt;FastNumericArrayConvertor&lt;/a&gt;&lt;/code&gt; and &lt;code&gt;&lt;a href="https://github.com/NVIDIA/cutlass/blob/757275f2796bb901575c633e2a32bc76ca84ffec/include/cutlass/gemm/warp/mma_mixed_input_tensor_op.h#L120"&gt;FragmentShuffler&lt;/a&gt;&lt;/code&gt;, respectively. 

&lt;/p&gt;&lt;p&gt;
&lt;code&gt;FastNumericArrayConvertor&lt;/code&gt; operates on 4xU8 in 32-bit registers without unpacking individual 1xU8 values. Furthermore, it uses less expensive arithmetic operations which reduces the number of instructions and increases the speed of the conversion. 
&lt;/p&gt;
&lt;p&gt;
The conversion sequence for &lt;a href="https://github.com/NVIDIA/cutlass/blob/757275f2796bb901575c633e2a32bc76ca84ffec/include/cutlass/numeric_conversion.h#L2514"&gt;U8-to-F16&lt;/a&gt; is shown below. The operations use packed 32b registers, avoiding explicit unpacking and packing. &lt;code&gt;FastNumericArrayConvertor&lt;/code&gt; uses the &lt;code&gt;&lt;a href="https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-prmt"&gt;permute byte&lt;/a&gt;&lt;/code&gt; instruction to rearrange bytes of 4xU8 into two registers. Additionally, &lt;code&gt;FastNumericArrayConvertor&lt;/code&gt; does not use expensive integer-to-floating-point conversion instructions and employs vectorized operations to obtain the packed results in &lt;em&gt;two&lt;/em&gt; 32-bit registers containing 2x(2xF16) values. The &lt;code&gt;FastNumericArrayConvertor&lt;/code&gt; for U8-to-F16 uses approximately six operations, a 1.6× reduction relative to the approach shown above.
&lt;/p&gt;

&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgRhtLljZ8wfnfnyXQsYZlNMDZ-cUqCV7wPvGimtPtU3JcKJLv6lCDT_PfBBmyp0TuHRgFIZ2cbgEDeL5bqke4FGUcpGMbAhcIBJxQcpcuWZIlqG1yXOHPf5BivF26_qlDnR9W2Y3RVE36ZB7rEGZO3x2Xva7-rqBZkoI7l4gnzBWLYfIrmhFBNN8DpaoEA/s1392/image201.png" imageanchor="1" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="733" data-original-width="1392" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgRhtLljZ8wfnfnyXQsYZlNMDZ-cUqCV7wPvGimtPtU3JcKJLv6lCDT_PfBBmyp0TuHRgFIZ2cbgEDeL5bqke4FGUcpGMbAhcIBJxQcpcuWZIlqG1yXOHPf5BivF26_qlDnR9W2Y3RVE36ZB7rEGZO3x2Xva7-rqBZkoI7l4gnzBWLYfIrmhFBNN8DpaoEA/s16000/image201.png" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;&lt;code&gt;FastNumericArrayConvertor&lt;/code&gt; utilizes &lt;code&gt;&lt;a href="https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-prmt"&gt;permute bytes&lt;/a&gt;&lt;/code&gt; and packed arithmetic, reducing the number of instructions in the data type conversion.&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;


&lt;p&gt;
&lt;code&gt;FragmentShuffler&lt;/code&gt; handles the layout conformance by shuffling data in a way that allows the use of wider bitwidth load operations, increasing shared memory bandwidth utilization and reducing the total number of operations.
&lt;/p&gt;
&lt;p&gt;
The NVIDIA Ampere architecture provides a load matrix instruction (&lt;code&gt;&lt;a href="https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-ldmatrix"&gt;ldmatrix&lt;/a&gt;&lt;/code&gt;). The &lt;span style="color: #54863f;"&gt;&lt;code&gt;ldmatrix&lt;/code&gt;&lt;/span&gt; is a warp-level operation, where 32 threads of a warp move the data from shared memory to registers in the &lt;em&gt;shape&lt;/em&gt; and &lt;em&gt;layout&lt;/em&gt; that the &lt;span style="color: #54863f;"&gt;&lt;code&gt;mma&lt;/code&gt;&lt;/span&gt; operation consumes for matrices &lt;em&gt;A&lt;/em&gt; and &lt;em&gt;B&lt;/em&gt;. The use of &lt;span style="color: #54863f;"&gt;&lt;code&gt;ldmatrix&lt;/code&gt;&lt;/span&gt; &lt;em&gt;reduces&lt;/em&gt; the number of load instructions and &lt;em&gt;increases&lt;/em&gt; the memory bandwidth utilization. Since the &lt;span style="color: #54863f;"&gt;&lt;code&gt;ldmatrix&lt;/code&gt;&lt;/span&gt; instruction moves U8 data to registers, the layout after the load conforms with the U8*U8 &lt;span style="color: #54863f;"&gt;&lt;code&gt;mma&lt;/code&gt;&lt;/span&gt; operation, not with the F16*F16 &lt;span style="color: #54863f;"&gt;&lt;code&gt;mma&lt;/code&gt;&lt;/span&gt; operation. We implemented &lt;code&gt;FragmentShuffler&lt;/code&gt; to rearrange the data within registers using shuffle (&lt;code&gt;&lt;a href="https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#warp-shuffle-functions"&gt;shfl.sync&lt;/a&gt;&lt;/code&gt;) operations to achieve layout conformance.

&lt;/p&gt;&lt;p&gt;
The most significant contribution of this work is to achieve layout conformance through register shuffles, avoiding offline pre-processing in global memory or narrower bitwidth shared memory loads. Furthermore, we provide implementations for &lt;code&gt;FastNumericArrayConvertor&lt;/code&gt; covering data type conversion from &lt;a href="https://github.com/NVIDIA/cutlass/blob/757275f2796bb901575c633e2a32bc76ca84ffec/include/cutlass/numeric_conversion.h#L2514"&gt;U8-to-F16&lt;/a&gt;, &lt;a href="https://github.com/NVIDIA/cutlass/blob/757275f2796bb901575c633e2a32bc76ca84ffec/include/cutlass/numeric_conversion.h#L2448"&gt;S8-to-F16&lt;/a&gt;, &lt;a href="https://github.com/NVIDIA/cutlass/blob/757275f2796bb901575c633e2a32bc76ca84ffec/include/cutlass/numeric_conversion.h#L2546"&gt;U8-to-BF16&lt;/a&gt;, and &lt;a href="https://github.com/NVIDIA/cutlass/blob/757275f2796bb901575c633e2a32bc76ca84ffec/include/cutlass/numeric_conversion.h#L2588"&gt;S8-to-BF16&lt;/a&gt;.
&lt;/p&gt;



&lt;div style="line-height: 40%;"&gt;
    &lt;br /&gt;
&lt;/div&gt;
&lt;h2&gt;Performance results&lt;/h2&gt;


&lt;p&gt;
We measured the performance of eight mixed-input variants of &lt;em&gt;our method&lt;/em&gt; (shown below in blue and red; varying the data types of matrix &lt;em&gt;A&lt;/em&gt; and &lt;em&gt;B&lt;/em&gt;) and two &lt;em&gt;mixed-precision&lt;/em&gt; data types (shown in green) on an NVIDIA A100 SXM chip. The performance results are shown in &lt;a href="https://en.wikipedia.org/wiki/FLOPS"&gt;FLOPS&lt;/a&gt; (higher is better). Notably, the first eight matrix multiplications require additional operations relative to the last two, because the mixed-precision variants directly target hardware-accelerated Tensor Core operations and do not need data type conversion or layout conformance. Even so, our approach demonstrates mixed-input matrix multiplication performance only slightly below or on par with mixed-precision.
&lt;/p&gt;

&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEg-Dq_2LmFUlg0KlNIJvFufCUMZujNc9LcoMnSURpGQwGbM75vXuS-Nm9ZH-7ItgWmZaBSUS3yawN0u3K21tbWTdijU4fVNgEyS33jOztyGfvNvLEw6IBiJO3JSmpctQtN8tvZmagEYQNSP3mmBQnXJ8GeNlQymbeqrKjFycjkKnHL_5FC8V6WR858byfm_/s1999/image2.png" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="1180" data-original-width="1999" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEg-Dq_2LmFUlg0KlNIJvFufCUMZujNc9LcoMnSURpGQwGbM75vXuS-Nm9ZH-7ItgWmZaBSUS3yawN0u3K21tbWTdijU4fVNgEyS33jOztyGfvNvLEw6IBiJO3JSmpctQtN8tvZmagEYQNSP3mmBQnXJ8GeNlQymbeqrKjFycjkKnHL_5FC8V6WR858byfm_/s16000/image2.png" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;Mixed-input matrix multiplication performance on an NVIDIA A100 40GB SXM4 chip for a compute-bound matrix problem shape &lt;code&gt;m=3456, n=4096, k=2048&lt;/code&gt;.&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;






&lt;div style="line-height: 40%;"&gt;
    &lt;br /&gt;
&lt;/div&gt;
&lt;h2&gt;Acknowledgements&lt;/h2&gt;


&lt;p&gt;
&lt;em&gt;We would like to mention several folks who have contributed through technical brainstorming and improving the blog post including, Quentin Colombet, Jacques Pienaar, Allie Culp, Calin Cascaval, Ashish Gondimalla, Matt Walsh, Marek Kolodziej, and Aman Bhatia. We would like to thank our NVIDIA partners Rawn Henry, Pradeep Ramani, Vijay Thakkar, Haicheng Wu, Andrew Kerr, Matthew Nicely, and Vartika Singh.&lt;/em&gt;
&lt;/p&gt;
</content><link href="http://blog.research.google/feeds/5144906729109253495/comments/default" rel="replies" title="Post Comments" type="application/atom+xml"/><link href="http://blog.research.google/2024/01/mixed-input-matrix-multiplication.html#comment-form" rel="replies" title="0 Comments" type="text/html"/><link href="http://www.blogger.com/feeds/8474926331452026626/posts/default/5144906729109253495" rel="edit" type="application/atom+xml"/><link href="http://www.blogger.com/feeds/8474926331452026626/posts/default/5144906729109253495" rel="self" type="application/atom+xml"/><link href="http://blog.research.google/2024/01/mixed-input-matrix-multiplication.html" rel="alternate" title="Mixed-input matrix multiplication performance optimizations" type="text/html"/><author><name>Google AI</name><uri>http://www.blogger.com/profile/12098626514775266161</uri><email>noreply@blogger.com</email><gd:image height="16" rel="http://schemas.google.com/g/2005#thumbnail" src="https://img1.blogblog.com/img/b16-rounded.gif" width="16"/></author><media:thumbnail xmlns:media="http://search.yahoo.com/mrss/" height="72" url="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhEKJJf1R773hab0veY6zffF2Nf_yfV2mk8YU9yRnuBDD3ak1o0iXecWlJw2x7bL-Ez2MX1c21MXk65VMK5IsoLpJ1H6BTC6k7BvVWl_gHJpJIOG2cm3BwP4V-HCScGHYIynuskbhvu1uorQGprHGbOFmfGI7E5UWemJcZ0xSC3tC5DolBYgyBwugl6OOLr/s72-c/matrixhero.png" width="72"/><thr:total>0</thr:total></entry><entry><id>tag:blogger.com,1999:blog-8474926331452026626.post-1418736582601940076</id><published>2024-01-23T14:27:00.000-08:00</published><updated>2024-01-23T14:27:09.785-08:00</updated><category scheme="http://www.blogger.com/atom/ns#" term="Deep Learning"/><category scheme="http://www.blogger.com/atom/ns#" term="Graphs"/><title type="text">Exphormer: Scaling transformers for graph-structured data</title><content type="html">&lt;span class="byline-author"&gt;Posted by Ameya Velingker, Research Scientist, Google Research, and Balaji Venkatachalam, Software 
Engineer, Google&lt;/span&gt;

&lt;img src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhbovKreBr7RlKc4L36E6rLqiZBZzJSq5GLijCkomHREon5tYXd-7C2pppMXnL5Mj2d82kZGnPlarrrMzQOfRnN8kVvqDh1GnadIJ-hbaaS8VjYzCpaD-DgYor5cKx-OhTGZk9iCy5MjtwG2Q9eTyQiipDr5ViMdl2vkxfbLzWnB3wmLb8YfvVsTJ1FnOmw/s1600/EXPHORMER%2005large.gif" style="display: none;" /&gt;

&lt;p&gt;
&lt;a href="https://en.wikipedia.org/wiki/Graph_(discrete_mathematics)"&gt;Graphs&lt;/a&gt;, in which objects and their relations are represented as nodes (or vertices) and edges (or links) between pairs of nodes, are ubiquitous in computing and machine learning (ML). For example, social networks, road networks, and molecular structure and interactions are all domains in which underlying datasets have a natural graph structure. ML can be used to learn the properties of nodes, edges, or entire graphs. 
&lt;/p&gt;
&lt;a name='more'&gt;&lt;/a&gt;

&lt;p&gt;
A common approach to learning on graphs is &lt;a href="https://distill.pub/2021/gnn-intro/"&gt;graph neural networks&lt;/a&gt; (GNNs), which operate on graph data by applying an optimizable transformation to node, edge, and global attributes. The most typical class of GNNs operates via a &lt;a href="https://wandb.ai/graph-neural-networks/spatial/reports/An-Introduction-to-Message-Passing-Graph-Neural-Networks--VmlldzoyMDI2NTg2"&gt;message-passing&lt;/a&gt; framework, whereby each layer aggregates the representation of a node with those of its immediate neighbors.
&lt;/p&gt;
&lt;p&gt;
Recently, &lt;a href="https://arxiv.org/abs/2012.09699"&gt;graph transformer models&lt;/a&gt; have emerged as a popular alternative to message-passing GNNs. These models build on the success of &lt;a href="https://en.wikipedia.org/wiki/Transformer_(machine-learning_model)"&gt;Transformer architectures&lt;/a&gt; in natural language processing (NLP), adapting them to graph-structured data. The attention mechanism in graph transformers can be modeled by an interaction graph, in which edges represent pairs of nodes that attend to each other. Unlike message-passing architectures, graph transformers have an interaction graph that is separate from the input graph. The typical interaction graph is a complete graph, which signifies a full attention mechanism that models direct interactions between all pairs of nodes. However, this creates quadratic computational and memory bottlenecks that limit the applicability of graph transformers to datasets on small graphs with at most a few thousand nodes. Making graph transformers scalable has been considered one of the most important research directions in the field (see &lt;a href="https://towardsdatascience.com/graph-ml-in-2022-where-are-we-now-f7f8242599e0"&gt;the first open problem here&lt;/a&gt;).
&lt;/p&gt;
&lt;p&gt;
A natural remedy is to use a &lt;em&gt;sparse&lt;/em&gt; interaction graph with fewer edges. &lt;a href="https://dl.acm.org/doi/10.1145/3530811"&gt;Many sparse and efficient transformers have been proposed&lt;/a&gt; to eliminate the quadratic bottleneck for sequences; however, they do not generally extend to graphs in a principled manner.
&lt;/p&gt;
&lt;p&gt;
In “&lt;a href="https://arxiv.org/abs/2303.06147"&gt;Exphormer: Sparse Transformers for Graphs&lt;/a&gt;”, presented at &lt;a href="https://icml.cc/Conferences/2023/Dates"&gt;ICML 2023&lt;/a&gt;, we address the scalability challenge by introducing a sparse attention framework for transformers that is designed specifically for graph data. The Exphormer framework makes use of expander graphs, a powerful tool from &lt;a href="https://en.wikipedia.org/wiki/Spectral_graph_theory"&gt;spectral graph theory&lt;/a&gt;, and is able to achieve strong empirical results on a wide variety of datasets. Our implementation of Exphormer is now available on &lt;a href="https://github.com/hamed1375/Exphormer"&gt;GitHub&lt;/a&gt;.
&lt;/p&gt;
&lt;br /&gt; 

&lt;h2&gt;Expander graphs&lt;/h2&gt;


&lt;p&gt;
A key idea at the heart of Exphormer is the use of &lt;a href="https://en.wikipedia.org/wiki/Expander_graph"&gt;expander graphs&lt;/a&gt;, which are sparse yet well-connected graphs that have some useful properties — 1) their matrix representations have linear-algebraic properties similar to those of a complete graph, and 2) they exhibit rapid mixing of random walks, i.e., a small number of steps in a random walk from any starting node is enough to ensure convergence to a “stable” distribution on the nodes of the graph. Expanders have found applications in diverse areas, such as algorithms, pseudorandomness, complexity theory, and error-correcting codes.
&lt;/p&gt;
&lt;p&gt;
A common class of expander graphs is that of &lt;em&gt;d&lt;/em&gt;-regular expanders, in which there are &lt;em&gt;d&lt;/em&gt; edges from every node (i.e., every node has degree &lt;em&gt;d&lt;/em&gt;). The quality of an expander graph is measured by its &lt;em&gt;spectral gap&lt;/em&gt;, an algebraic property of its &lt;a href="https://en.wikipedia.org/wiki/Adjacency_matrix"&gt;adjacency matrix&lt;/a&gt; (a matrix representation of the graph in which rows and columns are indexed by nodes and entries indicate whether pairs of nodes are connected by an edge). Those that maximize the spectral gap are known as &lt;a href="https://en.wikipedia.org/wiki/Ramanujan_graph"&gt;Ramanujan graphs&lt;/a&gt; — they achieve a gap of &lt;em&gt;d&lt;/em&gt; − 2√(&lt;em&gt;d&lt;/em&gt; − 1), which is essentially the best possible among &lt;em&gt;d&lt;/em&gt;-regular graphs. A number of deterministic and randomized constructions of Ramanujan graphs have been proposed over the years for various values of &lt;em&gt;d&lt;/em&gt;. We use a &lt;a href="https://arxiv.org/abs/cs/0405020"&gt;randomized expander construction of Friedman&lt;/a&gt;, which produces near-Ramanujan graphs.
&lt;/p&gt;

&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEg495FZQZ12yMiNhU8C7XUKEJ88H5_v2PPrzhwcDOVnSaVEtdCXaL7py-LzwZZkybKwIaePLHKpdmD6qALfskdjeaA8ML9QYHMwWkxz2ZnhWYqoV1PpnNgbRRfm0pSVYJVrtUpONyyF5PfswJ_QoxD-9vI9F3rF6VQbIRDDIbgvOFc35vTEF9uxizKNpli9/s843/image1.gif" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="843" data-original-width="800" height="320" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEg495FZQZ12yMiNhU8C7XUKEJ88H5_v2PPrzhwcDOVnSaVEtdCXaL7py-LzwZZkybKwIaePLHKpdmD6qALfskdjeaA8ML9QYHMwWkxz2ZnhWYqoV1PpnNgbRRfm0pSVYJVrtUpONyyF5PfswJ_QoxD-9vI9F3rF6VQbIRDDIbgvOFc35vTEF9uxizKNpli9/s320/image1.gif" width="304" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;&lt;span id="docs-internal-guid-2920b38b-7fff-2fa8-a3cd-06dfd3ba9968"&gt;&lt;span face="Arial, sans-serif" style="font-size: 10pt; font-style: italic; font-variant-alternates: normal; font-variant-east-asian: normal; font-variant-numeric: normal; font-variant-position: normal; vertical-align: baseline; white-space-collapse: preserve;"&gt;Expander graphs are at the heart of Exphormer. A good expander is sparse yet exhibits rapid mixing of random walks, making its global connectivity suitable for an interaction graph in a graph transformer model.&lt;/span&gt;&lt;/span&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;

&lt;p&gt;Exphormer replaces the dense, fully connected interaction graph of a standard Transformer with edges of a sparse &lt;em&gt;d&lt;/em&gt;-regular expander graph. Intuitively, the spectral approximation and mixing properties of an expander graph allow distant nodes to communicate with each other after one stacks multiple attention layers in a graph transformer architecture, even though the nodes may not attend to each other directly. Furthermore, by ensuring that &lt;em&gt;d&lt;/em&gt; is constant (independent of the number of nodes), we obtain a linear number of edges in the resulting interaction graph.&lt;/p&gt;
&lt;br /&gt; 

&lt;h2&gt;Exphormer: Constructing a sparse interaction graph&lt;/h2&gt;


&lt;p&gt;
Exphormer combines expander edges with the input graph and virtual nodes. More specifically, the sparse attention mechanism of Exphormer builds an interaction graph consisting of three types of edges:
&lt;/p&gt;
&lt;ul&gt;

&lt;li&gt;Edges from the input graph (&lt;em&gt;local attention&lt;/em&gt;)

&lt;/li&gt;&lt;li&gt;Edges from a constant-degree expander graph (&lt;em&gt;expander attention&lt;/em&gt;)

&lt;/li&gt;&lt;li&gt;Edges from every node to a small set of virtual nodes (&lt;em&gt;global attention&lt;/em&gt;)
&lt;/li&gt;
&lt;/ul&gt;

&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEiS7VdL6OcWCmXd-wTtx-qs_nA7qTYZFJOTHS7RZNS3Io_w4km3NM4opPsQBXu1u50KjDA43CsG0hoi1l7I9gq_KGBMvwKEjlWQKBzCeytLQHujF-4K4r9E4F4Q0APvw7le4twjGbDyEiVfEzhbsovhzk2_g4Xd4jwCo66HW7xbnLvm3WPBsHaoq-hDAYX8/s800/image1.gif" style="margin-left: auto; margin-right: auto;"&gt;&lt;img border="0" data-original-height="430" data-original-width="800" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEiS7VdL6OcWCmXd-wTtx-qs_nA7qTYZFJOTHS7RZNS3Io_w4km3NM4opPsQBXu1u50KjDA43CsG0hoi1l7I9gq_KGBMvwKEjlWQKBzCeytLQHujF-4K4r9E4F4Q0APvw7le4twjGbDyEiVfEzhbsovhzk2_g4Xd4jwCo66HW7xbnLvm3WPBsHaoq-hDAYX8/s16000/image1.gif" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;&lt;span id="docs-internal-guid-ac11d16d-7fff-62da-cf18-7ba830f677d3"&gt;&lt;span face="Arial, sans-serif" style="font-size: 10pt; font-style: italic; font-variant-alternates: normal; font-variant-east-asian: normal; font-variant-numeric: normal; font-variant-position: normal; vertical-align: baseline; white-space-collapse: preserve;"&gt;Exphormer builds an interaction graph by combining three types of edges. The resulting graph has good connectivity properties and retains the inductive bias of the input dataset graph while still remaining sparse.&lt;/span&gt;&lt;/span&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;

&lt;p&gt;
  Each component serves a specific purpose: the edges from the input graph retain the inductive bias from the input graph structure (which typically gets lost in a fully connected attention module). Meanwhile, expander edges provide good global connectivity and random-walk mixing properties (which spectrally approximate the complete graph with far fewer edges). Finally, virtual nodes serve as global “memory sinks” that can directly communicate with every node. Although each virtual node adds a number of edges equal to the number of nodes in the input graph, the resulting graph is still sparse. The degree of the expander graph and the number of virtual nodes are hyperparameters that can be tuned to improve quality metrics.
&lt;/p&gt;
&lt;p&gt;
Furthermore, since we use an expander graph of constant degree and a small constant number of virtual nodes for the global attention, the resulting sparse attention mechanism is linear in the size of the original input graph, i.e., it models a number of direct interactions on the order of the total number of nodes and edges.
&lt;/p&gt;
&lt;p&gt;
We additionally show that Exphormer is as expressive as the dense transformer and obeys universal approximation properties. In particular, when the sparse attention graph of Exphormer is augmented with self loops (edges connecting a node to itself), it can universally approximate continuous functions [&lt;a href="https://arxiv.org/abs/1912.10077"&gt;1&lt;/a&gt;, &lt;a href="https://arxiv.org/abs/2006.04862"&gt;2&lt;/a&gt;].
&lt;/p&gt;
&lt;div style="line-height: 40%;"&gt;
    &lt;br /&gt;
&lt;/div&gt;
&lt;h3&gt;Relation to sparse Transformers for sequences&lt;/h3&gt;


&lt;p&gt;
It is interesting to compare Exphormer to sparse attention methods for sequences. Perhaps the architecture most conceptually similar to our approach is &lt;a href="https://blog.research.google/2021/03/constructing-transformers-for-longer.html"&gt;BigBird&lt;/a&gt;, which builds an interaction graph by combining different components. BigBird also uses virtual nodes, but, unlike Exphormer, it uses window attention and random attention from an &lt;a href="https://en.wikipedia.org/wiki/Erd%C5%91s%E2%80%93R%C3%A9nyi_model"&gt;Erdős-Rényi&lt;/a&gt; random graph model for the remaining components.
&lt;/p&gt;
&lt;p&gt;
Window attention in BigBird looks at the tokens surrounding a token in a sequence — the local neighborhood attention in Exphormer can be viewed as a generalization of window attention to graphs.
&lt;/p&gt;
&lt;p&gt;
The Erdős-Rényi graph on &lt;em&gt;n&lt;/em&gt; nodes, &lt;em&gt;G(n, p)&lt;/em&gt;, which connects every pair of nodes independently with probability &lt;em&gt;p&lt;/em&gt;, also functions as an expander graph for suitably high &lt;em&gt;p&lt;/em&gt;. However, a superlinear number of edges (Ω(&lt;em&gt;n&lt;/em&gt; log &lt;em&gt;n&lt;/em&gt;)) is needed to ensure that an Erdős-Rényi graph is connected, let alone a good expander. On the other hand, the expanders used in Exphormer have only a &lt;em&gt;linear&lt;/em&gt; number of edges.
&lt;/p&gt;
&lt;br /&gt; 

&lt;h2&gt;Experimental results&lt;/h2&gt;


&lt;p&gt;
Earlier works have shown the use of full graph Transformer-based models on datasets with graphs of size up to 5,000 nodes. To evaluate the performance of Exphormer, we build upon the celebrated &lt;a href="https://github.com/rampasek/GraphGPS"&gt;GraphGPS framework&lt;/a&gt; [&lt;a href="https://arxiv.org/abs/2205.12454"&gt;3&lt;/a&gt;], which combines both message passing and graph transformers and achieves state-of-the-art performance on a number of datasets. We show that replacing dense attention with Exphormer for the graph attention component in the GraphGPS framework allows one to achieve models with comparable or better performance, often with fewer trainable parameters.
&lt;/p&gt;
&lt;p&gt;
Furthermore, Exphormer notably allows graph transformer architectures to scale well beyond the usual graph size limits mentioned above. Exphormer can scale up to datasets of 10,000+ node graphs, such as the &lt;a href="https://arxiv.org/abs/1811.05868"&gt;Coauthor dataset&lt;/a&gt;, and even beyond to larger graphs such as the well-known &lt;a href="https://ogb.stanford.edu/docs/nodeprop/#ogbn-arxiv"&gt;ogbn-arxiv dataset&lt;/a&gt;, a citation network, which consists of 170K nodes and 1.1 million edges.
&lt;/p&gt;

&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEi-HJWH6mqX6N9ytZPbz6wawfMLzF2ey50Ot2BcowvPbQ3FaNwhlEZ3htvDbhq1C6ckLykf0yk3A1sIG0aPGaT8G_aSLj_A-AOfl8NIZdygdkn0C26RzZS9d-9KjyP1f_Zy7suN-iqvYR4zSCgqCXrhP8hVIirUgi6VGEBGx9I_AZikzc_ACKskBMBMPoSw/s1600/ExphormerPerformance.png" style="display: block; margin-left: auto; margin-right: auto; padding: 1em 0px; text-align: center;"&gt;&lt;img alt="" border="0" data-original-height="190" data-original-width="1522" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEi-HJWH6mqX6N9ytZPbz6wawfMLzF2ey50Ot2BcowvPbQ3FaNwhlEZ3htvDbhq1C6ckLykf0yk3A1sIG0aPGaT8G_aSLj_A-AOfl8NIZdygdkn0C26RzZS9d-9KjyP1f_Zy7suN-iqvYR4zSCgqCXrhP8hVIirUgi6VGEBGx9I_AZikzc_ACKskBMBMPoSw/s1600/ExphormerPerformance.png" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;Results comparing Exphormer to standard GraphGPS on the five &lt;a href="https://arxiv.org/abs/2206.08164"&gt;Long Range Graph Benchmark&lt;/a&gt; datasets. We note that Exphormer achieved state-of-the-art results on four of the five datasets (PascalVOC-SP, COCO-SP, Peptides-Struct, PCQM-Contact) at the time of the paper’s publication.&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;

&lt;!--&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;&lt;tbody&gt;&lt;tr&gt;&lt;td style="text-align: center;"&gt;&lt;a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhDbVRMNKr2z64PowKGcaM4NDeiIfzrpfyXe02tRD8tpr_DS99oIjewDwOZZJkNgOr7ZSYwsE5jVqpwOz0Tj2z68SkQzCWtZrhC3cXf2WWfJEZmSfOq3xlGIjdfx-9V0CkbYYv6LU63i1B-suztAyK0Dx8udq2SYSX4TEeP5Erw021KZY8L4FEVNV3BOXaL/s1600/ExphormerPerformance.png" style="display: block; margin-left: auto; margin-right: auto; padding: 1em 0px; text-align: center;"&gt;&lt;img alt="" border="0" data-original-height="202" data-original-width="1655" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhDbVRMNKr2z64PowKGcaM4NDeiIfzrpfyXe02tRD8tpr_DS99oIjewDwOZZJkNgOr7ZSYwsE5jVqpwOz0Tj2z68SkQzCWtZrhC3cXf2WWfJEZmSfOq3xlGIjdfx-9V0CkbYYv6LU63i1B-suztAyK0Dx8udq2SYSX4TEeP5Erw021KZY8L4FEVNV3BOXaL/s1600/ExphormerPerformance.png" /&gt;&lt;/a&gt;&lt;/td&gt;&lt;/tr&gt;&lt;tr&gt;&lt;td class="tr-caption" style="text-align: center;"&gt;Results comparing Exphormer to standard GraphGPS on the five &lt;a href="https://arxiv.org/abs/2206.08164"&gt;Long Range Graph Benchmark&lt;/a&gt; datasets. We note that Exphormer achieved state-of-the-art results on four of the five datasets (PascalVOC-SP, COCO-SP, Peptides-Struct, PCQM-Contact) at the time of publication.&lt;/td&gt;&lt;/tr&gt;&lt;/tbody&gt;&lt;/table&gt;--&gt;

&lt;!--&lt;table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"&gt;
  &lt;tbody&gt;&lt;tr&gt;
   &lt;td align="left"&gt;&lt;strong&gt;Model&amp;nbsp;&lt;/strong&gt;
   &lt;/td&gt;
   &lt;td align="center"&gt;&lt;strong&gt;&amp;nbsp;PascalVOC-SP&amp;nbsp;&lt;/strong&gt;
&lt;br&gt;
     &amp;nbsp;&lt;font size="-1"&gt;F1 score &lt;/font&gt;&lt;strong&gt;↑&lt;/strong&gt;&amp;nbsp;
   &lt;/td&gt;
   &lt;td align="center"&gt;&lt;strong&gt;&amp;nbsp;COCO-SP&amp;nbsp;&lt;/strong&gt;
&lt;br&gt;
     &amp;nbsp;&lt;font size="-1"&gt;F1 score &lt;/font&gt;&lt;strong&gt;↑&lt;/strong&gt;&amp;nbsp;
   &lt;/td&gt;
   &lt;td align="center"&gt;&lt;strong&gt;&amp;nbsp;Peptides-Func&amp;nbsp;&lt;/strong&gt;
&lt;br&gt;
     &amp;nbsp;&lt;font size="-1"&gt;AP &lt;/font&gt;&lt;strong&gt;↑&lt;/strong&gt;&amp;nbsp;
   &lt;/td&gt;
   &lt;td align="center"&gt;&lt;strong&gt;&amp;nbsp;Peptides-Struct&amp;nbsp;&lt;/strong&gt;
&lt;br&gt;
     &amp;nbsp;&lt;font size="-1"&gt;MAE &lt;/font&gt;&lt;strong&gt;↓&lt;/strong&gt;&amp;nbsp;
   &lt;/td&gt;
   &lt;td align="center"&gt;&lt;strong&gt;&amp;nbsp;PCQM-Contact&lt;/strong&gt;
&lt;br&gt;
     &amp;nbsp;&lt;font size="-1"&gt;MRR &lt;/font&gt;&lt;strong&gt;↑&lt;/strong&gt;
   &lt;/td&gt;
  &lt;/tr&gt;
    &lt;tr&gt;&lt;td colspan="6"&gt;&lt;div style="line-height: 40%;"&gt;&lt;br /&gt;&lt;/div&gt;&lt;/td&gt;&lt;/tr&gt; 
  &lt;tr&gt;
    &lt;td&gt;Standard GraphGPS&amp;nbsp;
   &lt;/td&gt;
   &lt;td align="center"&gt;&amp;nbsp;0.375 ± 0.011&amp;nbsp;
   &lt;/td&gt;
   &lt;td align="center"&gt;&amp;nbsp;0.341 ± 0.004&amp;nbsp;
   &lt;/td&gt;
    &lt;td align="center"&gt;&amp;nbsp;&lt;strong&gt;0.654 ± 0.004&lt;/strong&gt; &amp;nbsp;
   &lt;/td&gt;
   &lt;td align="center"&gt;&amp;nbsp;0.250 ± 0.001&amp;nbsp;
   &lt;/td&gt;
   &lt;td align="center"&gt;&amp;nbsp;0.334 ± 0.001
   &lt;/td&gt;
  &lt;/tr&gt;
  &lt;tr&gt;
   &lt;td&gt;&lt;em&gt;Exphormer (ours)&amp;nbsp;&lt;/em&gt;
   &lt;/td&gt;
   &lt;td align="center"&gt;&lt;strong&gt;&lt;em&gt;&amp;nbsp;0.398 ± 0.004&amp;nbsp;&lt;/em&gt;&lt;/strong&gt;
   &lt;/td&gt;
   &lt;td align="center"&gt;&lt;strong&gt;&lt;em&gt;&amp;nbsp;0.346 ± 0.001&amp;nbsp;&lt;/em&gt;&lt;/strong&gt;
   &lt;/td&gt;
   &lt;td align="center"&gt;&lt;em&gt;&amp;nbsp;0.653 ± 0.004&amp;nbsp;&lt;/em&gt;
   &lt;/td&gt;
   &lt;td align="center"&gt;&lt;strong&gt;&lt;em&gt;&amp;nbsp;0.248 ± 0.001&amp;nbsp;&lt;/em&gt;&lt;/strong&gt;
   &lt;/td&gt;
   &lt;td align="center"&gt;&lt;strong&gt;&lt;em&gt;&amp;nbsp;0.364 ± 0.002&lt;/em&gt;&lt;/strong&gt;
   &lt;/td&gt;
  &lt;/tr&gt; 
&lt;/tbody&gt;&lt;/table&gt;--&gt; 


&lt;p&gt;
Finally, we observe that Exphormer, which creates an overlay graph of small diameter via expanders, exhibits the ability to effectively learn long-range dependencies. The &lt;a href="https://arxiv.org/abs/2206.08164"&gt;Long Range Graph Benchmark&lt;/a&gt;&amp;nbsp;is a suite of five graph learning datasets designed to measure the ability of models to capture long-range interactions. Results show that Exphormer-based models outperform standard GraphGPS models (which were previously state-of-the-art on four out of five datasets at the time of publication).
&lt;/p&gt;
&lt;br /&gt; 

&lt;h2&gt;Conclusion&lt;/h2&gt;


&lt;p&gt;
Graph transformers have emerged as an important architecture for ML that adapts the highly successful sequence-based transformers used in NLP to graph-structured data. Scalability has, however, proven to be a major challenge in enabling the use of graph transformers on datasets with large graphs. In this post, we have presented Exphormer, a sparse attention framework that uses expander graphs to improve scalability of graph transformers. Exphormer is shown to have important theoretical properties and exhibit strong empirical performance, particularly on datasets where it is crucial to learn long-range dependencies. For more information, we point the reader to a short presentation &lt;a href="https://icml.cc/virtual/2023/poster/23782"&gt;video&lt;/a&gt; from ICML 2023.
&lt;/p&gt;
&lt;br /&gt; 

&lt;h2&gt;Acknowledgements&lt;/h2&gt;


&lt;p&gt;
&lt;em&gt;We thank our research collaborators Hamed Shirzad and Danica J. Sutherland from The University of British Columbia as well as Ali Kemal Sinop from Google Research. Special thanks to Tom Small for creating the animation used in this post.&lt;/em&gt;
&lt;/p&gt;</content><link href="http://blog.research.google/feeds/1418736582601940076/comments/default" rel="replies" title="Post Comments" type="application/atom+xml"/><link href="http://blog.research.google/2024/01/exphormer-scaling-transformers-for.html#comment-form" rel="replies" title="0 Comments" type="text/html"/><link href="http://www.blogger.com/feeds/8474926331452026626/posts/default/1418736582601940076" rel="edit" type="application/atom+xml"/><link href="http://www.blogger.com/feeds/8474926331452026626/posts/default/1418736582601940076" rel="self" type="application/atom+xml"/><link href="http://blog.research.google/2024/01/exphormer-scaling-transformers-for.html" rel="alternate" title="Exphormer: Scaling transformers for graph-structured data" type="text/html"/><author><name>Google AI</name><uri>http://www.blogger.com/profile/12098626514775266161</uri><email>noreply@blogger.com</email><gd:image height="16" rel="http://schemas.google.com/g/2005#thumbnail" src="https://img1.blogblog.com/img/b16-rounded.gif" width="16"/></author><media:thumbnail xmlns:media="http://search.yahoo.com/mrss/" height="72" url="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhbovKreBr7RlKc4L36E6rLqiZBZzJSq5GLijCkomHREon5tYXd-7C2pppMXnL5Mj2d82kZGnPlarrrMzQOfRnN8kVvqDh1GnadIJ-hbaaS8VjYzCpaD-DgYor5cKx-OhTGZk9iCy5MjtwG2Q9eTyQiipDr5ViMdl2vkxfbLzWnB3wmLb8YfvVsTJ1FnOmw/s72-c/EXPHORMER%2005large.gif" width="72"/><thr:total>0</thr:total></entry></feed>