Same way you train any other model desu. The core loop is below; the rest is setup, logging, saving, early stopping, etc.
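model = AutoencoderKL.from_pretrained(...)
optimizer = torch.optim.AdamW(model.parameters())  # any optimizer works
for (inputs, targets) in training_dataloader:
    outputs = model(inputs)  # may need .sample if the model returns an output object
    loss = loss_function(outputs, targets)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()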
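For a fuller picture of that loop with the logging, saving and early stop parts filled in, here is a minimal sketch. Everything in it is an assumption for illustration: `TinyAutoencoder` is a hypothetical stand-in for the real pretrained model, the data is random noise, and the MSE loss, AdamW settings and `patience` value are arbitrary picks, not recommendations.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Hypothetical stand-in for the real model: a tiny conv autoencoder.
class TinyAutoencoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Conv2d(3, 4, kernel_size=3, stride=2, padding=1)
        self.decoder = nn.ConvTranspose2d(4, 3, kernel_size=4, stride=2, padding=1)

    def forward(self, x):
        return self.decoder(self.encoder(x))

def train(model, loader, val_loader, epochs=20, patience=3, ckpt_path=None):
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-2)
    loss_function = nn.MSELoss()  # assumed loss; swap in whatever you train with
    best, bad_epochs, history = float("inf"), 0, []
    for epoch in range(epochs):
        model.train()
        for inputs, targets in loader:
            optimizer.zero_grad()
            loss = loss_function(model(inputs), targets)
            loss.backward()
            optimizer.step()
        # Validation pass, no gradients needed.
        model.eval()
        with torch.no_grad():
            val = sum(loss_function(model(x), y).item() for x, y in val_loader)
            val /= len(val_loader)
        history.append(val)
        print(f"epoch {epoch}: val_loss={val:.4f}")  # logging
        if val < best:
            best, bad_epochs = val, 0
            if ckpt_path is not None:
                torch.save(model.state_dict(), ckpt_path)  # saving
        else:
            bad_epochs += 1
            if bad_epochs >= patience:  # early stop
                break
    return history

# Random images as placeholder data; for an autoencoder the target is the input.
images = torch.rand(32, 3, 16, 16)
loader = DataLoader(TensorDataset(images, images), batch_size=8, shuffle=True)
val_loader = DataLoader(TensorDataset(images[:8], images[:8]), batch_size=8)
history = train(TinyAutoencoder(), loader, val_loader)
```

Swapping in the real model mostly means replacing `TinyAutoencoder()` and the placeholder data; the loop itself stays the same shape.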
>the dataset structure
Images or pixel tensors. You could write your loader to read image files, or you could preprocess to pixel tensors first and save them as npy. Either way, it's recommended to avoid many thousands of small individual files: use webdataset, or if you're preprocessing, concatenate and save as batches.
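A rough sketch of the "concat then save as a batch" route, using only NumPy. The helper names (`write_shards`, `iter_batches`), the shard naming scheme, the uint8 HWC layout and the random stand-in images are all assumptions for illustration, not a fixed format.

```python
import tempfile
from pathlib import Path

import numpy as np

def write_shards(images, out_dir, shard_size):
    """Stack preprocessed images into a few large .npy files instead of many small ones."""
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    paths = []
    for start in range(0, len(images), shard_size):
        shard = np.stack(images[start:start + shard_size])  # (N, H, W, C) uint8
        path = out_dir / f"shard_{start // shard_size:05d}.npy"
        np.save(path, shard)
        paths.append(path)
    return paths

def iter_batches(paths, batch_size):
    """Yield float32 batches in [0, 1], channels-first, one shard at a time."""
    for path in paths:
        shard = np.load(path, mmap_mode="r")  # memory-map so the whole shard isn't read up front
        for start in range(0, len(shard), batch_size):
            batch = np.asarray(shard[start:start + batch_size], dtype=np.float32) / 255.0
            yield batch.transpose(0, 3, 1, 2)  # (N, C, H, W)

# Random uint8 arrays standing in for decoded, resized images.
rng = np.random.default_rng(0)
images = [rng.integers(0, 256, size=(32, 32, 3), dtype=np.uint8) for _ in range(10)]

with tempfile.TemporaryDirectory() as tmp:
    paths = write_shards(images, tmp, shard_size=4)
    batches = list(iter_batches(paths, batch_size=2))
```

Storing uint8 and converting to float at load time keeps the files about 4x smaller than saving float32 directly.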
Usage is whatever encode + decode costs at the resolution and batch size you're training at, plus backward/optimizer. You can always adjust resolution and batch size to fit your hardware; it's always about scale in the end.
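For quick back-of-envelope trading between resolution and batch size, something like the sketch below can help. It assumes activation cost grows roughly linearly with batch × height × width, which is only a rule of thumb for the convolutional parts (attention layers grow faster with resolution, and weights/optimizer state don't scale with either). The function name and the 256x256 baseline are made up for illustration.

```python
def relative_cost(batch, height, width, base_batch=1, base_height=256, base_width=256):
    """Rough multiplier on activation cost versus a baseline setting.

    Assumption: cost is proportional to batch * height * width.
    """
    return (batch * height * width) / (base_batch * base_height * base_width)

# Batch 4 at 512x512 versus batch 1 at 256x256: 4x the batch, 4x the pixels.
multiplier = relative_cost(4, 512, 512)

# Halving the resolution on each side frees room for roughly 4x the batch.
same_budget = relative_cost(16, 256, 256)
```

Both settings above come out to the same multiplier, which is the sense in which resolution and batch can be traded against each other; measure on your own hardware before trusting it.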