In [6]:
import ee
import geemap
import datetime
import os
import numpy as np
import rasterio
import time
from sklearn.model_selection import train_test_split
!pip install -q geedim geemap
# --- 1. INITIALIZATION ---
try:
ee.Initialize(project='[REDACTED_FOR_SECURITY]')
except:
ee.Authenticate()
ee.Initialize(project='[REDACTED_FOR_SECURITY]')
# Configuration
year = 2021
START_DATE = f'{year}-10-16'
END_DATE = f'{year + 1}-04-16'
ASSET_ID = '[REDACTED_FOR_SECURITY]'
DRIVE_FOLDER = 'SatMAE_Scratch_Results_12Frames'
SAVE_DIR = f'/content/drive/MyDrive/{DRIVE_FOLDER}/'
if not os.path.exists(SAVE_DIR):
os.makedirs(SAVE_DIR)
# Assets
wheat_mask = ee.Image(ASSET_ID)
roi = wheat_mask.geometry()
s2Bands = ['B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B8', 'B8A', 'B11', 'B12']
MAX_CLOUD_PROB = 70
# Sentinel-2 Processing
s2Raw = (ee.ImageCollection('COPERNICUS/S2_SR_HARMONIZED')
.filterDate(START_DATE, END_DATE).filterBounds(roi)
.filter(ee.Filter.lt('CLOUDY_PIXEL_PERCENTAGE', 100))
.map(lambda img: img.updateMask(
img.select('MSK_SNWPRB').lt(1)
.add(img.select('SCL').gte(4).And(img.select('SCL').lte(6)))
)))
s2Clouds = ee.ImageCollection('COPERNICUS/S2_CLOUD_PROBABILITY').filterDate(START_DATE, END_DATE).filterBounds(roi)
s2Joined = ee.ImageCollection(ee.Join.saveFirst('cloud_mask_img').apply(
s2Raw, s2Clouds, ee.Filter.equals(leftField='system:index', rightField='system:index')))
def processS2(img):
isCloud = ee.Image(img.get('cloud_mask_img')).select('probability').gt(MAX_CLOUD_PROB)
optical = img.select(s2Bands).multiply(0.025)
ndvi = optical.normalizedDifference(['B8', 'B4']).rename('NDVI')
combined_mask = isCloud.Not()
return (optical.addBands(ndvi).updateMask(combined_mask).copyProperties(img, ['system:time_start']))
s2Clean = s2Joined.map(processS2)
# Fortnightly Composites
intervals = []
cur = datetime.datetime.fromisoformat(START_DATE)
end = datetime.datetime.fromisoformat(END_DATE)
while cur < end:
if cur.day == 1:
nxt = cur + datetime.timedelta(days=14)
str_lbl = f"{cur.year}_{cur.month:02d}_1"
else:
next_month = cur.replace(day=1) + datetime.timedelta(days=32)
nxt = next_month.replace(day=1)
str_lbl = f"{cur.year}_{cur.month:02d}_2"
if nxt > end: nxt = end
intervals.append([cur.strftime('%Y-%m-%d'), (nxt).strftime('%Y-%m-%d'), str_lbl])
cur = nxt
intervals = intervals[:12]
print(f"Generating {len(intervals)} Time Frames...")
def makeComposite(item):
start, end = ee.Date(ee.List(item).get(0)), ee.Date(ee.List(item).get(1))
return s2Clean.filterDate(start, end).qualityMosaic('NDVI').select(s2Bands)\
.clamp(0, 250).toByte().unmask(255)\
.set('fortnight_label', ee.List(item).get(2))
fortnightlyCol = ee.ImageCollection.fromImages(ee.List(intervals).map(makeComposite))
# --- ROBUST EXPORT LOGIC ---
def monitor_task(task):
print(f"Submitting Task: {task.config['description']}...")
try:
task.start()
except Exception as e:
if "already exists" in str(e):
print("Task already running/completed.")
return
raise e
while True:
status = task.status()
state = status['state']
if state in ['COMPLETED', 'FAILED', 'CANCELLED']:
print(f"\nTask Finished with state: {state}")
if state == 'FAILED':
print(f"Error Message: {status.get('error_message', 'Unknown Error')}")
raise RuntimeError("GEE Export Failed.")
break
print(f"Status: {state}...", end='\r')
time.sleep(10)
def wait_for_file(filepath, timeout=300): # INCREASED TO 5 MINUTES
"""Waits for a file to appear in Colab Drive mount with Forced Refresh."""
print(f"Waiting for Drive sync: {os.path.basename(filepath)}...")
start = time.time()
while not os.path.exists(filepath):
elapsed = time.time() - start
if elapsed > timeout:
print(f"\n TIMEOUT: File {filepath} not found locally.")
print("Check your Google Drive folder in a new tab.")
print("If the file exists there, simply RE-RUN this cell.")
raise FileNotFoundError(f"Timeout: {filepath} did not appear after {timeout}s.")
# AGGRESSIVE REFRESH LOGIC
# We list the directory contents to force the OS to check the drive again
try:
_ = os.listdir(os.path.dirname(filepath))
except:
pass
print(f"Syncing... ({int(elapsed)}s)", end='\r')
time.sleep(5)
print(f"\n Found: {filepath}")
return True
def generate_local_dataset():
print("Starting Robust Data Generation (Batch Export)...")
centroid = roi.centroid()
buffer_poly = centroid.buffer(2500).bounds()
# 1. Export Stack
stack_name = 'local_stack'
stack_file = os.path.join(SAVE_DIR, f'{stack_name}.tif')
# Only submit if file missing
if not os.path.exists(stack_file):
task_stack = ee.batch.Export.image.toDrive(
image=fortnightlyCol.toBands(),
description='Export_Stack_12Frames',
folder=DRIVE_FOLDER,
fileNamePrefix=stack_name,
region=buffer_poly,
scale=10,
crs='EPSG:4326',
fileFormat='GeoTIFF',
maxPixels=1e9
)
monitor_task(task_stack)
else:
print(f"Found existing stack: {stack_file}")
# 2. Export Mask
mask_name = 'local_mask'
mask_file = os.path.join(SAVE_DIR, f'{mask_name}.tif')
if not os.path.exists(mask_file):
task_mask = ee.batch.Export.image.toDrive(
image=wheat_mask,
description='Export_Mask',
folder=DRIVE_FOLDER,
fileNamePrefix=mask_name,
region=buffer_poly,
scale=10,
crs='EPSG:4326',
fileFormat='GeoTIFF',
maxPixels=1e9
)
monitor_task(task_mask)
else:
print(f"Found existing mask: {mask_file}")
# 3. Sync Wait (Crucial Step)
wait_for_file(stack_file)
wait_for_file(mask_file)
# 4. Process
print("Processing GeoTIFFs to Numpy...")
with rasterio.open(stack_file) as src_stack, rasterio.open(mask_file) as src_mask:
img_data = src_stack.read()
mask_data = src_mask.read(1)
H, W = mask_data.shape
if img_data.shape[0] < 120:
print(f"Warning: Got {img_data.shape[0]} bands. Padding with zeros.")
padding = np.zeros((120 - img_data.shape[0], H, W), dtype=img_data.dtype)
img_data = np.concatenate([img_data, padding], axis=0)
img_reshaped = img_data[:120].reshape(12, 10, H, W)
X_list, y_list = [], []
if H < 224 or W < 224:
print("ROI too small.")
return
else:
for r in range(0, H-224, 112):
for c in range(0, W-224, 112):
chip_x = img_reshaped[:, :, r:r+224, c:c+224]
chip_y = mask_data[r:r+224, c:c+224]
if np.mean(chip_x == 255) > 0.30: continue
if np.mean(chip_y > 0) > 0.01:
X_list.append(chip_x)
y_list.append(chip_y)
if len(X_list) == 0: raise RuntimeError("No valid tiles found.")
X = np.array(X_list).astype(np.uint8)
y = np.array(y_list).astype(np.uint8)[:, None, :, :]
split_idx = int(0.8 * len(X))
X_train, X_val = X[:split_idx], X[split_idx:]
y_train, y_val = y[:split_idx], y[split_idx:]
print(f"Spatial Split: Train={len(X_train)}, Val={len(X_val)}")
np.save(os.path.join(SAVE_DIR, 'train_x.npy'), X_train)
np.save(os.path.join(SAVE_DIR, 'train_y.npy'), y_train)
np.save(os.path.join(SAVE_DIR, 'val_x.npy'), X_val)
np.save(os.path.join(SAVE_DIR, 'val_y.npy'), y_val)
print("Done.")
generate_local_dataset()
Generating 12 Time Frames... Starting Robust Data Generation (Batch Export)... Submitting Task: Export_Stack_12Frames... Task Finished with state: COMPLETED Submitting Task: Export_Mask... Task Finished with state: COMPLETED Waiting for Drive sync: local_stack.tif... ❌ TIMEOUT: File /content/drive/MyDrive/SatMAE_Scratch_Results_12Frames/local_stack.tif not found locally. Check your Google Drive folder in a new tab. If the file exists there, simply RE-RUN this cell.
--------------------------------------------------------------------------- FileNotFoundError Traceback (most recent call last) /tmp/ipython-input-3111439864.py in <cell line: 0>() 229 print("Done.") 230 --> 231 generate_local_dataset() /tmp/ipython-input-3111439864.py in generate_local_dataset() 179 180 # 3. Sync Wait (Crucial Step) --> 181 wait_for_file(stack_file) 182 wait_for_file(mask_file) 183 /tmp/ipython-input-3111439864.py in wait_for_file(filepath, timeout) 117 print("Check your Google Drive folder in a new tab.") 118 print("If the file exists there, simply RE-RUN this cell.") --> 119 raise FileNotFoundError(f"Timeout: {filepath} did not appear after {timeout}s.") 120 121 # AGGRESSIVE REFRESH LOGIC FileNotFoundError: Timeout: /content/drive/MyDrive/SatMAE_Scratch_Results_12Frames/local_stack.tif did not appear after 300s.
In [1]:
import ee
import geemap
import os
import numpy as np
import rasterio
from google.colab import drive
# --- 1. MOUNT DRIVE ---
# This allows us to access the files you already saved
drive.mount('/content/drive')
# --- 2. INITIALIZE EARTH ENGINE ---
try:
ee.Initialize(project='[REDACTED_FOR_SECURITY]')
print(" Earth Engine Initialized.")
except:
ee.Authenticate()
ee.Initialize(project='[REDACTED_FOR_SECURITY]')
print(" Earth Engine Authenticated & Initialized.")
# --- 3. CONFIGURATION ---
DRIVE_FOLDER = 'SatMAE_Scratch_Results_12Frames'
SAVE_DIR = f'/content/drive/MyDrive/{DRIVE_FOLDER}/'
stack_file = os.path.join(SAVE_DIR, 'local_stack.tif')
mask_file = os.path.join(SAVE_DIR, 'local_mask.tif')
# --- 4. PROCESSING LOGIC (TIF -> NPY) ---
def process_existing_files():
if not os.path.exists(stack_file) or not os.path.exists(mask_file):
print(f" Error: Files not found in {SAVE_DIR}")
print("Please check your Google Drive folder name.")
return
print(f" Found Stack: {os.path.basename(stack_file)}")
print(f" Found Mask: {os.path.basename(mask_file)}")
print(" Processing into Training Data (Spatial Split + NoData Filter)...")
with rasterio.open(stack_file) as src_stack, rasterio.open(mask_file) as src_mask:
img_data = src_stack.read()
mask_data = src_mask.read(1)
# Padding Check (Ensure 120 channels: 12 frames * 10 bands)
H, W = mask_data.shape
if img_data.shape[0] < 120:
print(f" Warning: Got {img_data.shape[0]} bands. Padding with zeros.")
padding = np.zeros((120 - img_data.shape[0], H, W), dtype=img_data.dtype)
img_data = np.concatenate([img_data, padding], axis=0)
img_reshaped = img_data[:120].reshape(12, 10, H, W)
# Chip Generation
X_list, y_list = [], []
if H < 224 or W < 224:
print(" Error: ROI is smaller than 224x224.")
return
else:
# Stride 112 = 50% Overlap
for r in range(0, H-224, 112):
for c in range(0, W-224, 112):
chip_x = img_reshaped[:, :, r:r+224, c:c+224]
chip_y = mask_data[r:r+224, c:c+224]
# Filter Garbage (Discard if >30% is NoData/255)
if np.mean(chip_x == 255) > 0.30: continue
# Filter Empty Labels (Keep if >1% wheat)
if np.mean(chip_y > 0) > 0.01:
X_list.append(chip_x)
y_list.append(chip_y)
if len(X_list) == 0:
print(" No valid tiles found! (Mask might be empty or threshold too strict).")
return
# Spatial Split (Top 80% Train, Bottom 20% Val)
X = np.array(X_list).astype(np.uint8)
y = np.array(y_list).astype(np.uint8)[:, None, :, :]
split_idx = int(0.8 * len(X))
X_train, X_val = X[:split_idx], X[split_idx:]
y_train, y_val = y[:split_idx], y[split_idx:]
print("-" * 30)
print(f" Data Ready!")
print(f" Train Tiles: {len(X_train)}")
print(f" Val Tiles: {len(X_val)}")
# Save to Drive
np.save(os.path.join(SAVE_DIR, 'train_x.npy'), X_train)
np.save(os.path.join(SAVE_DIR, 'train_y.npy'), y_train)
np.save(os.path.join(SAVE_DIR, 'val_x.npy'), X_val)
np.save(os.path.join(SAVE_DIR, 'val_y.npy'), y_val)
print(" .npy files saved successfully.")
# Run the processing
process_existing_files()
Mounted at /content/drive Earth Engine Authenticated & Initialized. Found Stack: local_stack.tif Found Mask: local_mask.tif Processing into Training Data (Spatial Split + NoData Filter)... ------------------------------ ✅ Data Ready! Train Tiles: 9 Val Tiles: 3 💾 .npy files saved successfully.
In [3]:
import os
import numpy as np
import rasterio
from google.colab import drive
# --- 1. MOUNT DRIVE ---
# We need to see your Google Drive files
drive.mount('/content/drive', force_remount=True)
# --- 2. CONFIGURATION ---
# This must match exactly where your files are
DRIVE_FOLDER = 'SatMAE_Scratch_Results_12Frames'
SAVE_DIR = f'/content/drive/MyDrive/{DRIVE_FOLDER}/'
stack_file = os.path.join(SAVE_DIR, 'local_stack.tif')
mask_file = os.path.join(SAVE_DIR, 'local_mask.tif')
# --- 3. PROCESSING FUNCTION (No GEE required) ---
def generate_dataset_from_drive():
print(f" Checking folder: {SAVE_DIR}")
if not os.path.exists(stack_file) or not os.path.exists(mask_file):
print(f" Error: Files not found!")
print(f" Looking for: {stack_file}")
print(" Please check if the folder name is correct.")
return
print(f"Found Stack: {os.path.basename(stack_file)}")
print(f" Found Mask: {os.path.basename(mask_file)}")
print(" Processing GeoTIFFs into NPY (Training Data)...")
# Open the files directly from Drive
with rasterio.open(stack_file) as src_stack, rasterio.open(mask_file) as src_mask:
img_data = src_stack.read()
mask_data = src_mask.read(1)
# 1. Padding Check (Ensure 120 channels)
H, W = mask_data.shape
if img_data.shape[0] < 120:
print(f" Warning: Got {img_data.shape[0]} bands. Padding with zeros.")
padding = np.zeros((120 - img_data.shape[0], H, W), dtype=img_data.dtype)
img_data = np.concatenate([img_data, padding], axis=0)
img_reshaped = img_data[:120].reshape(12, 10, H, W)
# 2. Chip Generation
X_list, y_list = [], []
if H < 224 or W < 224:
print(" Error: Image is smaller than 224x224. Cannot create tiles.")
return
else:
# Stride 112 = 50% Overlap
for r in range(0, H-224, 112):
for c in range(0, W-224, 112):
chip_x = img_reshaped[:, :, r:r+224, c:c+224]
chip_y = mask_data[r:r+224, c:c+224]
# Filter Garbage (255)
if np.mean(chip_x == 255) > 0.30: continue
# Filter Empty Labels
if np.mean(chip_y > 0) > 0.01:
X_list.append(chip_x)
y_list.append(chip_y)
if len(X_list) == 0:
print(" No valid tiles found! (Check if mask has wheat pixels).")
return
# 3. Spatial Split (Top 80% Train, Bottom 20% Val)
X = np.array(X_list).astype(np.uint8)
y = np.array(y_list).astype(np.uint8)[:, None, :, :]
split_idx = int(0.8 * len(X))
X_train, X_val = X[:split_idx], X[split_idx:]
y_train, y_val = y[:split_idx], y[split_idx:]
print("-" * 30)
print(f" Data Processed Successfully!")
print(f" Train Chips: {len(X_train)}")
print(f" Val Chips: {len(X_val)}")
# 4. Save NPY files back to Drive
np.save(os.path.join(SAVE_DIR, 'train_x.npy'), X_train)
np.save(os.path.join(SAVE_DIR, 'train_y.npy'), y_train)
np.save(os.path.join(SAVE_DIR, 'val_x.npy'), X_val)
np.save(os.path.join(SAVE_DIR, 'val_y.npy'), y_val)
print(" .npy files saved.")
# --- RUN IT ---
generate_dataset_from_drive()
Mounted at /content/drive Checking folder: /content/drive/MyDrive/SatMAE_Scratch_Results_12Frames/ Found Stack: local_stack.tif Found Mask: local_mask.tif Processing GeoTIFFs into NPY (Training Data)... ------------------------------ Data Processed Successfully! Train Chips: 9 Val Chips: 3 .npy files saved.
In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim.swa_utils import AveragedModel, SWALR, update_bn
from tqdm.auto import tqdm
# --- 1. METRICS ---
def compute_metrics(pred_probs, targets, threshold=0.5):
pred_mask = (pred_probs > threshold).float()
intersection = (pred_mask * targets).sum()
union = pred_mask.sum() + targets.sum() - intersection
iou = (intersection + 1e-6) / (union + 1e-6)
dice = (2 * intersection + 1e-6) / (pred_mask.sum() + targets.sum() + 1e-6)
return iou.item(), dice.item()
# --- 2. LOSS FUNCTIONS ---
class FastHausdorffLoss(nn.Module):
def __init__(self):
super().__init__()
def forward(self, pred, gt):
probs = torch.sigmoid(pred)
# Morphological Erosion via MaxPool
# Invert -> MaxPool -> Invert = Erosion
p_eroded = -F.max_pool2d(-probs, kernel_size=3, stride=1, padding=1)
t_eroded = -F.max_pool2d(-gt, kernel_size=3, stride=1, padding=1)
p_edge = probs - p_eroded
t_edge = gt - t_eroded
# --- MATH FIX: SCALING ---
# Edge differences are sparse (mostly 0). The mean is tiny.
# We multiply by 20.0 to bring the magnitude to ~0.1-0.5 range
# so it balances with Dice Loss.
return (p_edge - t_edge).abs().mean() * 20.0
class CompoundLoss(nn.Module):
def forward(self, inputs, targets):
# Dice
probs = torch.sigmoid(inputs)
inter = (probs * targets).sum()
dice_loss = 1 - (2. * inter / (probs.sum() + targets.sum() + 1e-6))
# Fast Hausdorff
fast_hd = FastHausdorffLoss()(inputs, targets)
# 70% Overlap, 30% Shape
return 0.7 * dice_loss + 0.3 * fast_hd
class TestTimeAugmentation:
def __init__(self, model): self.model = model
def apply(self, x):
preds = []
with torch.no_grad():
preds.append(torch.sigmoid(self.model(x)))
preds.append(torch.flip(torch.sigmoid(self.model(torch.flip(x, [-1]))), [-1]))
preds.append(torch.flip(torch.sigmoid(self.model(torch.flip(x, [-2]))), [-2]))
preds.append(torch.flip(torch.sigmoid(self.model(torch.flip(x, [-2,-1]))), [-2,-1]))
return torch.stack(preds).mean(dim=0)
# --- 3. ARCHITECTURE (Standard SatMAE) ---
class SatMAEPatchEmbed(nn.Module):
def __init__(self, img_size=224, patch_size=16, in_chans=10, embed_dim=768):
super().__init__()
self.num_patches = (img_size // patch_size) ** 2
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
def forward(self, x):
B, T, C, H, W = x.shape
x = x.reshape(B * T, C, H, W)
x = self.proj(x).flatten(2).transpose(1, 2)
return x.reshape(B, T, -1, 768)
class SatMAEEncoder(nn.Module):
def __init__(self, img_size=224, patch_size=16, in_chans=10, num_frames=12, embed_dim=768, depth=12, num_heads=12):
super().__init__()
self.patch_embed = SatMAEPatchEmbed(img_size, patch_size, in_chans, embed_dim)
self.pos_embed_spatial = nn.Parameter(torch.zeros(1, 1, self.patch_embed.num_patches, embed_dim))
self.pos_embed_temporal = nn.Parameter(torch.zeros(1, num_frames, 1, embed_dim))
nn.init.trunc_normal_(self.pos_embed_spatial, std=0.02)
nn.init.trunc_normal_(self.pos_embed_temporal, std=0.02)
self.blocks = nn.TransformerEncoder(nn.TransformerEncoderLayer(embed_dim, num_heads, int(embed_dim*4), 0.1, 'gelu', batch_first=True), depth)
self.norm = nn.LayerNorm(embed_dim)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
nn.init.trunc_normal_(m.weight, std=0.02)
if m.bias is not None: nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0); nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
def forward(self, x):
B, T, C, H, W = x.shape
x = self.patch_embed(x)
x = x + self.pos_embed_spatial + self.pos_embed_temporal
x = x.reshape(B, T * x.shape[2], -1)
x = self.norm(self.blocks(x))
return x.reshape(B, T, 14, 14, -1)
class SatMAESegmentation(nn.Module):
def __init__(self, img_size=224, patch_size=16, in_chans=10, num_frames=12, embed_dim=768):
super().__init__()
self.num_frames = num_frames; self.chans = in_chans
self.encoder = SatMAEEncoder(img_size, patch_size, in_chans, num_frames, embed_dim)
self.up1 = nn.ConvTranspose2d(embed_dim, 256, 2, 2)
self.conv1 = nn.Sequential(nn.Conv2d(256, 256, 3, 1, 1), nn.BatchNorm2d(256), nn.ReLU())
self.up2 = nn.ConvTranspose2d(256, 128, 2, 2)
self.conv2 = nn.Sequential(nn.Conv2d(128, 128, 3, 1, 1), nn.BatchNorm2d(128), nn.ReLU())
self.up3 = nn.ConvTranspose2d(128, 64, 2, 2)
self.conv3 = nn.Sequential(nn.Conv2d(64, 64, 3, 1, 1), nn.BatchNorm2d(64), nn.ReLU())
self.up4 = nn.ConvTranspose2d(64, 32, 2, 2)
self.conv4 = nn.Sequential(nn.Conv2d(32, 32, 3, 1, 1), nn.BatchNorm2d(32), nn.ReLU())
self.final = nn.Conv2d(32, 1, 1)
def forward(self, x):
if x.ndim == 4: B, _, H, W = x.shape; x = x.reshape(B, self.num_frames, self.chans, H, W)
features = self.encoder(x)
x = features.mean(dim=1).permute(0, 3, 1, 2)
x = self.conv1(self.up1(x))
x = self.conv2(self.up2(x))
x = self.conv3(self.up3(x))
x = self.conv4(self.up4(x))
return self.final(x)
# --- 4. TRAINER ---
class SatMAETrainer:
def __init__(self, model, loaders, device, lr=1e-4):
self.model = model; self.loaders = loaders; self.device = device
self.optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
self.scheduler = optim.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=150)
self.swa_model = AveragedModel(model)
self.swa_scheduler = SWALR(self.optimizer, swa_lr=lr)
self.criterion = CompoundLoss()
self.tta = TestTimeAugmentation(model)
self.history = {'train_loss': [], 'val_iou': []}
def fit(self, epochs=200):
print(f"Starting Scratch Training ({epochs} epochs)...")
use_swa = False; swa_start_epoch = 151; patience = 30; no_improv = 0; best_iou = 0
for ep in range(epochs):
self.model.train()
t_loss = 0
pbar = tqdm(self.loaders['train'], desc=f"Epoch {ep+1}/{epochs}", leave=False)
for batch in pbar:
if isinstance(batch, (list, tuple)): x, y = batch[0], batch[1]
else: x, y = batch
x, y = x.to(self.device), y.to(self.device)
self.optimizer.zero_grad()
logits = self.model(x)
loss = self.criterion(logits, y)
loss.backward()
self.optimizer.step()
t_loss += loss.item()
pbar.set_postfix({'loss': f"{loss.item():.4f}"})
# Validation
eval_model = self.swa_model if use_swa else self.model
eval_model.eval()
v_iou_sum = 0
with torch.no_grad():
for batch in self.loaders['val']:
if isinstance(batch, (list, tuple)): x, y = batch[0], batch[1]
else: x, y = batch
x, y = x.to(self.device), y.to(self.device)
if use_swa: probs = torch.sigmoid(eval_model(x))
else: probs = self.tta.apply(x)
iou, _ = compute_metrics(probs, y)
v_iou_sum += iou
avg_loss = t_loss / len(self.loaders['train'])
avg_iou = v_iou_sum / len(self.loaders['val']) if len(self.loaders['val']) > 0 else 0
self.history['train_loss'].append(avg_loss)
self.history['val_iou'].append(avg_iou)
print(f"Ep {ep+1} | Loss: {avg_loss:.4f} | Val IoU: {avg_iou:.4f} | SWA: {use_swa}")
if (ep + 1) % 5 == 0: torch.save(self.model.state_dict(), f"checkpoint_ep{ep+1}.pth")
if not use_swa:
self.scheduler.step()
if avg_iou > best_iou:
best_iou = avg_iou; no_improv = 0
torch.save(self.model.state_dict(), 'best_model.pth')
print(f" >>> New Best IoU: {best_iou:.4f}")
else:
no_improv += 1
if no_improv >= patience:
print(f" !! No Improvement. Triggering SWA."); use_swa = True; swa_start_epoch = ep + 1
if ep + 1 >= swa_start_epoch: use_swa = True
else:
self.swa_model.update_parameters(self.model); self.swa_scheduler.step()
if use_swa:
update_bn(self.loaders['train'], self.swa_model, device=self.device)
torch.save(self.swa_model.module.state_dict(), 'final_swa_model.pth')
print("Training Complete. SWA Model Saved.")
In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint # <--- CRITICAL IMPORT FOR MEMORY
# --- 1. METRICS ---
def compute_metrics(pred_probs, targets, threshold=0.5):
pred_mask = (pred_probs > threshold).float()
intersection = (pred_mask * targets).sum()
union = pred_mask.sum() + targets.sum() - intersection
iou = (intersection + 1e-6) / (union + 1e-6)
dice = (2 * intersection + 1e-6) / (pred_mask.sum() + targets.sum() + 1e-6)
return iou.item(), dice.item()
# --- 2. LOSS FUNCTIONS ---
class FastHausdorffLoss(nn.Module):
def __init__(self):
super().__init__()
def forward(self, pred, gt):
probs = torch.sigmoid(pred)
p_eroded = -F.max_pool2d(-probs, kernel_size=3, stride=1, padding=1)
t_eroded = -F.max_pool2d(-gt, kernel_size=3, stride=1, padding=1)
return (probs - p_eroded - (gt - t_eroded)).abs().mean() * 20.0
class CompoundLoss(nn.Module):
def forward(self, inputs, targets):
probs = torch.sigmoid(inputs)
inter = (probs * targets).sum()
dice = 1 - (2. * inter / (probs.sum() + targets.sum() + 1e-6))
hd = FastHausdorffLoss()(inputs, targets)
return 0.7 * dice + 0.3 * hd
class TestTimeAugmentation:
def __init__(self, model): self.model = model
def apply(self, x):
preds = []
with torch.no_grad():
preds.append(torch.sigmoid(self.model(x)))
preds.append(torch.flip(torch.sigmoid(self.model(torch.flip(x, [-1]))), [-1]))
preds.append(torch.flip(torch.sigmoid(self.model(torch.flip(x, [-2]))), [-2]))
preds.append(torch.flip(torch.sigmoid(self.model(torch.flip(x, [-2,-1]))), [-2,-1]))
return torch.stack(preds).mean(dim=0)
# --- 3. ARCHITECTURE (Memory Optimized) ---
class SatMAEPatchEmbed(nn.Module):
def __init__(self, img_size=224, patch_size=16, in_chans=10, embed_dim=768):
super().__init__()
self.num_patches = (img_size // patch_size) ** 2
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
def forward(self, x):
B, T, C, H, W = x.shape
x = x.reshape(B * T, C, H, W)
x = self.proj(x).flatten(2).transpose(1, 2)
return x.reshape(B, T, -1, 768)
class SatMAEEncoder(nn.Module):
def __init__(self, img_size=224, patch_size=16, in_chans=10, num_frames=12, embed_dim=768, depth=12, num_heads=12):
super().__init__()
self.patch_embed = SatMAEPatchEmbed(img_size, patch_size, in_chans, embed_dim)
self.pos_embed_spatial = nn.Parameter(torch.zeros(1, 1, self.patch_embed.num_patches, embed_dim))
self.pos_embed_temporal = nn.Parameter(torch.zeros(1, num_frames, 1, embed_dim))
nn.init.trunc_normal_(self.pos_embed_spatial, std=0.02)
nn.init.trunc_normal_(self.pos_embed_temporal, std=0.02)
# CHANGED: Use ModuleList to enable Checkpointing
self.blocks = nn.ModuleList([
nn.TransformerEncoderLayer(embed_dim, num_heads, int(embed_dim*4), 0.1, 'gelu', batch_first=True)
for _ in range(depth)
])
self.norm = nn.LayerNorm(embed_dim)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
nn.init.trunc_normal_(m.weight, std=0.02)
if m.bias is not None: nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0); nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
def forward(self, x):
B, T, C, H, W = x.shape
x = self.patch_embed(x)
x = x + self.pos_embed_spatial + self.pos_embed_temporal
x = x.reshape(B, T * x.shape[2], -1)
# --- GRADIENT CHECKPOINTING LOGIC ---
for blk in self.blocks:
if self.training:
# This saves massive memory by not storing activations
x = checkpoint(blk, x, use_reentrant=False)
else:
x = blk(x)
x = self.norm(x)
return x.reshape(B, T, 14, 14, -1)
class SatMAESegmentation(nn.Module):
def __init__(self, img_size=224, patch_size=16, in_chans=10, num_frames=12, embed_dim=768):
super().__init__()
self.num_frames = num_frames; self.chans = in_chans
self.encoder = SatMAEEncoder(img_size, patch_size, in_chans, num_frames, embed_dim)
self.up1 = nn.ConvTranspose2d(embed_dim, 256, 2, 2)
self.conv1 = nn.Sequential(nn.Conv2d(256, 256, 3, 1, 1), nn.BatchNorm2d(256), nn.ReLU())
self.up2 = nn.ConvTranspose2d(256, 128, 2, 2)
self.conv2 = nn.Sequential(nn.Conv2d(128, 128, 3, 1, 1), nn.BatchNorm2d(128), nn.ReLU())
self.up3 = nn.ConvTranspose2d(128, 64, 2, 2)
self.conv3 = nn.Sequential(nn.Conv2d(64, 64, 3, 1, 1), nn.BatchNorm2d(64), nn.ReLU())
self.up4 = nn.ConvTranspose2d(64, 32, 2, 2)
self.conv4 = nn.Sequential(nn.Conv2d(32, 32, 3, 1, 1), nn.BatchNorm2d(32), nn.ReLU())
self.final = nn.Conv2d(32, 1, 1)
def forward(self, x):
if x.ndim == 4: B, _, H, W = x.shape; x = x.reshape(B, self.num_frames, self.chans, H, W)
features = self.encoder(x)
x = features.mean(dim=1).permute(0, 3, 1, 2)
x = self.conv1(self.up1(x))
x = self.conv2(self.up2(x))
x = self.conv3(self.up3(x))
x = self.conv4(self.up4(x))
return self.final(x)
In [5]:
import torch
import torch.optim as optim
from torch.optim.swa_utils import AveragedModel, SWALR, update_bn
import os
import csv
import pandas as pd
import matplotlib.pyplot as plt
# --- CONFIGURATION ---
SAVE_DIR = '/content/drive/MyDrive/SatMAE_Scratch_Results_12Frames/'
CHECKPOINT_PATH = os.path.join(SAVE_DIR, "checkpoint_latest.pth")
BEST_MODEL_PATH = os.path.join(SAVE_DIR, "best_model.pth")
LOG_PATH = os.path.join(SAVE_DIR, "training_log.csv")
BATCH_SIZE = 16
MAX_EPOCHS = 5000
PATIENCE_TRIGGER = 100
SWA_DURATION = 50
# Ensure Directory
if not os.path.exists(SAVE_DIR): os.makedirs(SAVE_DIR)
# --- SETUP ---
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")
# Model (12 Frames, 12 Channels)
model = SatMAESegmentation(img_size=224, patch_size=16, in_chans=12, num_frames=12).to(device)
criterion = CompoundLoss()
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=50, T_mult=2)
# SWA Setup
swa_model = AveragedModel(model)
swa_scheduler = SWALR(optimizer, swa_lr=5e-5)
loaders = {
'train': DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=2),
'val': DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)
}
# --- RESUME LOGIC ---
start_epoch = 0
best_iou = 0.0
patience_counter = 0
swa_active = False
swa_epoch_counter = 0
history = {'epoch': [], 'train_loss': [], 'val_loss': [], 'val_iou': []}
if os.path.exists(CHECKPOINT_PATH):
print(" Found checkpoint. Resuming...")
ckpt = torch.load(CHECKPOINT_PATH, map_location=device)
model.load_state_dict(ckpt['model_state_dict'])
optimizer.load_state_dict(ckpt['optimizer_state_dict'])
scheduler.load_state_dict(ckpt['scheduler_state_dict'])
start_epoch = ckpt['epoch'] + 1
best_iou = ckpt['best_iou']
patience_counter = ckpt['patience_counter']
swa_active = ckpt['swa_active']
swa_epoch_counter = ckpt['swa_epoch_counter']
if swa_active:
swa_model.load_state_dict(ckpt['swa_state_dict'])
swa_scheduler.load_state_dict(ckpt['swa_scheduler_state_dict'])
# Load history from CSV
if os.path.exists(LOG_PATH):
df = pd.read_csv(LOG_PATH)
history['epoch'] = df['epoch'].tolist()
history['train_loss'] = df['train_loss'].tolist()
history['val_loss'] = df['val_loss'].tolist()
history['val_iou'] = df['val_iou'].tolist()
print(f"Resumed at Epoch {start_epoch}. Best IoU: {best_iou:.4f}")
# --- TRAINING LOOP ---
print(f" Starting Training. Max: {MAX_EPOCHS} Eps. Patience: {PATIENCE_TRIGGER}")
try:
for ep in range(start_epoch, MAX_EPOCHS):
model.train()
train_loss = 0
# Training Step
for x, y in tqdm(loaders['train'], desc=f"Ep {ep+1}", leave=False):
x, y = x.to(device), y.to(device)
optimizer.zero_grad()
preds = model(x)
loss = criterion(preds, y)
loss.backward()
optimizer.step()
train_loss += loss.item()
# Validation Step
eval_model = swa_model if swa_active else model
eval_model.eval()
val_loss = 0
val_iou_sum = 0
with torch.no_grad():
for x, y in loaders['val']:
x, y = x.to(device), y.to(device)
# Forward
if swa_active: preds = eval_model(x)
else: preds = model(x) # No TTA for speed during training loop
val_loss += criterion(preds, y).item()
# Metrics
probs = torch.sigmoid(preds)
iou, _ = compute_metrics(probs, y)
val_iou_sum += iou
# Stats
avg_t = train_loss / len(loaders['train'])
avg_v = val_loss / len(loaders['val'])
avg_iou = val_iou_sum / len(loaders['val'])
# Log to History
history['epoch'].append(ep+1)
history['train_loss'].append(avg_t)
history['val_loss'].append(avg_v)
history['val_iou'].append(avg_iou)
# Append to CSV immediately (Crash Safety)
with open(LOG_PATH, 'a', newline='') as f:
writer = csv.writer(f)
if ep == 0 and not os.path.exists(LOG_PATH):
writer.writerow(['epoch', 'train_loss', 'val_loss', 'val_iou'])
writer.writerow([ep+1, avg_t, avg_v, avg_iou])
# Status Message
status = ""
# --- LOGIC BRANCHING ---
if swa_active:
# SWA PHASE
swa_model.update_parameters(model)
swa_scheduler.step()
swa_epoch_counter += 1
status = f"SWA Mode ({swa_epoch_counter}/{SWA_DURATION})"
# Stop Condition
if swa_epoch_counter >= SWA_DURATION:
print(f"Epoch {ep+1} | T: {avg_t:.4f} | V: {avg_v:.4f} | IoU: {avg_iou:.4f} | {status}")
print(" SWA Complete. Saving Final Model.")
update_bn(loaders['train'], swa_model, device=device)
torch.save(swa_model.module.state_dict(), os.path.join(SAVE_DIR, "final_swa_model.pth"))
break # EXIT LOOP
else:
# NORMAL PHASE
scheduler.step()
if avg_iou > best_iou:
best_iou = avg_iou
patience_counter = 0
torch.save(model.state_dict(), BEST_MODEL_PATH)
status = f" Best IoU!"
else:
patience_counter += 1
status = f"No Improv ({patience_counter}/{PATIENCE_TRIGGER})"
# Trigger SWA?
if patience_counter >= PATIENCE_TRIGGER:
print(f" Patience Limit Reached. Triggering SWA for {SWA_DURATION} epochs...")
swa_active = True
swa_epoch_counter = 0
# Load best weights before starting SWA to ensure stability
model.load_state_dict(torch.load(BEST_MODEL_PATH))
print(f"Ep {ep+1} | T: {avg_t:.4f} | V: {avg_v:.4f} | IoU: {avg_iou:.4f} | {status}")
torch.save({
'epoch': ep,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'scheduler_state_dict': scheduler.state_dict(),
'swa_state_dict': swa_model.state_dict() if swa_active else None,
'swa_scheduler_state_dict': swa_scheduler.state_dict() if swa_active else None,
'best_iou': best_iou,
'patience_counter': patience_counter,
'swa_active': swa_active,
'swa_epoch_counter': swa_epoch_counter
}, CHECKPOINT_PATH)
except KeyboardInterrupt:
print("Training Interrupted. Checkpoint Saved.")
# Plot Results
plt.figure(figsize=(10, 5))
plt.plot(history['train_loss'], label='Train Loss')
plt.plot(history['val_loss'], label='Val Loss')
plt.legend()
plt.title("Training Curves")
plt.savefig(os.path.join(SAVE_DIR, "loss_curve.png"))
plt.show()
Device: cuda
--------------------------------------------------------------------------- NameError Traceback (most recent call last) /tmp/ipython-input-2473747799.py in <cell line: 0>() 37 38 loaders = { ---> 39 'train': DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=2), 40 'val': DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=2) 41 } NameError: name 'train_ds' is not defined
In [6]:
import torch
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.optim.swa_utils import AveragedModel, SWALR, update_bn
import os
import csv
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
# --- CONFIGURATION ---
SAVE_DIR = '/content/drive/MyDrive/SatMAE_Scratch_Results_12Frames/'
CHECKPOINT_PATH = os.path.join(SAVE_DIR, "checkpoint_latest.pth")
BEST_MODEL_PATH = os.path.join(SAVE_DIR, "best_model.pth")
LOG_PATH = os.path.join(SAVE_DIR, "training_log.csv")
BATCH_SIZE = 16
MAX_EPOCHS = 5000
PATIENCE_TRIGGER = 100
SWA_DURATION = 50
# Ensure Directory
if not os.path.exists(SAVE_DIR): os.makedirs(SAVE_DIR)
# --- 1. DATASET DEFINITION (INCLUDED HERE TO PREVENT ERRORS) ---
class SatMAEDataset(Dataset):
def __init__(self, x_path, y_path):
# Load in mmap mode (Zero RAM usage initially)
self.x_data = np.load(x_path, mmap_mode='r')
self.y_data = np.load(y_path, mmap_mode='r')
def __len__(self):
return len(self.x_data)
def __getitem__(self, idx):
# Read into RAM
x_np = self.x_data[idx]
y_np = self.y_data[idx]
# Fast Numpy Scaling
x_float = x_np.astype(np.float32)
y_float = y_np.astype(np.float32)
# Zero out '255' (NoData)
x_float[x_np == 255] = 0.0
y_float[y_np == 255] = 0.0
# Scale 0-250 -> 0-1
x_float /= 250.0
return torch.from_numpy(x_float), torch.from_numpy(y_float)
# --- 2. SETUP LOADERS ---
try:
print("Initializing Datasets...")
train_ds = SatMAEDataset(os.path.join(SAVE_DIR, 'train_x.npy'), os.path.join(SAVE_DIR, 'train_y.npy'))
val_ds = SatMAEDataset(os.path.join(SAVE_DIR, 'val_x.npy'), os.path.join(SAVE_DIR, 'val_y.npy'))
loaders = {
'train': DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=2),
'val': DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)
}
print(" Loaders Ready.")
except Exception as e:
print(f" Error loading data: {e}")
print("Did you run Cell 1 to generate the .npy files?")
raise e
# --- 3. MODEL SETUP ---
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")
# Dependency Check: Ensure Cell 2 was run
try:
model = SatMAESegmentation(img_size=224, patch_size=16, in_chans=10, num_frames=12).to(device)
criterion = CompoundLoss()
except NameError:
raise NameError(" Error: 'SatMAESegmentation' or 'CompoundLoss' not defined. Please RUN CELL 2 (Architecture) first.")
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=50, T_mult=2)
# SWA Setup
swa_model = AveragedModel(model)
swa_scheduler = SWALR(optimizer, swa_lr=5e-5)
# --- 4. RESUME LOGIC ---
start_epoch = 0
best_iou = 0.0
patience_counter = 0
swa_active = False
swa_epoch_counter = 0
history = {'epoch': [], 'train_loss': [], 'val_loss': [], 'val_iou': []}
if os.path.exists(CHECKPOINT_PATH):
print(" Found checkpoint. Resuming...")
ckpt = torch.load(CHECKPOINT_PATH, map_location=device)
model.load_state_dict(ckpt['model_state_dict'])
optimizer.load_state_dict(ckpt['optimizer_state_dict'])
scheduler.load_state_dict(ckpt['scheduler_state_dict'])
start_epoch = ckpt['epoch'] + 1
best_iou = ckpt['best_iou']
patience_counter = ckpt['patience_counter']
swa_active = ckpt['swa_active']
swa_epoch_counter = ckpt['swa_epoch_counter']
if swa_active:
swa_model.load_state_dict(ckpt['swa_state_dict'])
swa_scheduler.load_state_dict(ckpt['swa_scheduler_state_dict'])
if os.path.exists(LOG_PATH):
df = pd.read_csv(LOG_PATH)
history['epoch'] = df['epoch'].tolist()
history['train_loss'] = df['train_loss'].tolist()
history['val_loss'] = df['val_loss'].tolist()
history['val_iou'] = df['val_iou'].tolist()
print(f"Resumed at Epoch {start_epoch}. Best IoU: {best_iou:.4f}")
# --- 5. TRAINING LOOP ---
print(f" Starting Training. Max: {MAX_EPOCHS} Eps. Patience: {PATIENCE_TRIGGER}")
try:
for ep in range(start_epoch, MAX_EPOCHS):
model.train()
train_loss = 0
# Training Step
for x, y in tqdm(loaders['train'], desc=f"Ep {ep+1}", leave=False):
x, y = x.to(device), y.to(device)
optimizer.zero_grad()
preds = model(x)
loss = criterion(preds, y)
loss.backward()
optimizer.step()
train_loss += loss.item()
# Validation Step
eval_model = swa_model if swa_active else model
eval_model.eval()
val_loss = 0
val_iou_sum = 0
with torch.no_grad():
for x, y in loaders['val']:
x, y = x.to(device), y.to(device)
if swa_active: preds = eval_model(x)
else: preds = model(x)
val_loss += criterion(preds, y).item()
probs = torch.sigmoid(preds)
# Helper metric calc
pred_mask = (probs > 0.5).float()
inter = (pred_mask * y).sum()
union = pred_mask.sum() + y.sum() - inter
iou = (inter + 1e-6) / (union + 1e-6)
val_iou_sum += iou.item()
# Stats
avg_t = train_loss / len(loaders['train'])
avg_v = val_loss / len(loaders['val'])
avg_iou = val_iou_sum / len(loaders['val'])
# Log to History
history['epoch'].append(ep+1)
history['train_loss'].append(avg_t)
history['val_loss'].append(avg_v)
history['val_iou'].append(avg_iou)
# Append to CSV
with open(LOG_PATH, 'a', newline='') as f:
writer = csv.writer(f)
if ep == 0 and not os.path.exists(LOG_PATH):
writer.writerow(['epoch', 'train_loss', 'val_loss', 'val_iou'])
writer.writerow([ep+1, avg_t, avg_v, avg_iou])
# Status & Logic
status = ""
if swa_active:
swa_model.update_parameters(model)
swa_scheduler.step()
swa_epoch_counter += 1
status = f"SWA Mode ({swa_epoch_counter}/{SWA_DURATION})"
if swa_epoch_counter >= SWA_DURATION:
print(f"Ep {ep+1} | T: {avg_t:.4f} | V: {avg_v:.4f} | IoU: {avg_iou:.4f} | {status}")
print("SWA Complete. Saving Final Model.")
update_bn(loaders['train'], swa_model, device=device)
torch.save(swa_model.module.state_dict(), os.path.join(SAVE_DIR, "final_swa_model.pth"))
break
else:
scheduler.step()
if avg_iou > best_iou:
best_iou = avg_iou
patience_counter = 0
torch.save(model.state_dict(), BEST_MODEL_PATH)
status = f" Best IoU!"
else:
patience_counter += 1
status = f"No Improv ({patience_counter}/{PATIENCE_TRIGGER})"
if patience_counter >= PATIENCE_TRIGGER:
print(f" Patience Limit Reached. Triggering SWA for {SWA_DURATION} epochs...")
swa_active = True
swa_epoch_counter = 0
model.load_state_dict(torch.load(BEST_MODEL_PATH))
print(f"Ep {ep+1} | T: {avg_t:.4f} | V: {avg_v:.4f} | IoU: {avg_iou:.4f} | {status}")
# Save Checkpoint
torch.save({
'epoch': ep,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'scheduler_state_dict': scheduler.state_dict(),
'swa_state_dict': swa_model.state_dict() if swa_active else None,
'swa_scheduler_state_dict': swa_scheduler.state_dict() if swa_active else None,
'best_iou': best_iou,
'patience_counter': patience_counter,
'swa_active': swa_active,
'swa_epoch_counter': swa_epoch_counter
}, CHECKPOINT_PATH)
except KeyboardInterrupt:
print("Training Interrupted. Checkpoint Saved.")
# Plot Results
if len(history['train_loss']) > 0:
plt.figure(figsize=(10, 5))
plt.plot(history['train_loss'], label='Train Loss')
plt.plot(history['val_loss'], label='Val Loss')
plt.legend()
plt.title("Training Curves")
plt.savefig(os.path.join(SAVE_DIR, "loss_curve.png"))
plt.show()
Initializing Datasets... Loaders Ready. Device: cuda Starting Training. Max: 5000 Eps. Patience: 100
Ep 1: 0%| | 0/1 [00:00<?, ?it/s]
--------------------------------------------------------------------------- OutOfMemoryError Traceback (most recent call last) /tmp/ipython-input-581224262.py in <cell line: 0>() 129 x, y = x.to(device), y.to(device) 130 optimizer.zero_grad() --> 131 preds = model(x) 132 loss = criterion(preds, y) 133 loss.backward() /usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs) 1773 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc] 1774 else: -> 1775 return self._call_impl(*args, **kwargs) 1776 1777 # torchrec tests the code consistency with the following code /usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs) 1784 or _global_backward_pre_hooks or _global_backward_hooks 1785 or _global_forward_hooks or _global_forward_pre_hooks): -> 1786 return forward_call(*args, **kwargs) 1787 1788 result = None /tmp/ipython-input-3544700008.py in forward(self, x) 124 def forward(self, x): 125 if x.ndim == 4: B, _, H, W = x.shape; x = x.reshape(B, self.num_frames, self.chans, H, W) --> 126 features = self.encoder(x) 127 x = features.mean(dim=1).permute(0, 3, 1, 2) 128 x = self.conv1(self.up1(x)) /usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs) 1773 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc] 1774 else: -> 1775 return self._call_impl(*args, **kwargs) 1776 1777 # torchrec tests the code consistency with the following code /usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs) 1784 or _global_backward_pre_hooks or _global_backward_hooks 1785 or _global_forward_hooks or _global_forward_pre_hooks): -> 1786 return forward_call(*args, **kwargs) 1787 1788 result = None /tmp/ipython-input-3544700008.py in forward(self, x) 103 x = x + self.pos_embed_spatial + self.pos_embed_temporal 104 x = x.reshape(B, T * x.shape[2], -1) --> 105 x = self.norm(self.blocks(x)) 106 return x.reshape(B, T, 14, 14, -1) 107 /usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs) 1773 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc] 1774 else: -> 1775 return self._call_impl(*args, **kwargs) 1776 1777 # torchrec tests the code consistency with the following code /usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs) 1784 or _global_backward_pre_hooks or _global_backward_hooks 1785 or _global_forward_hooks or _global_forward_pre_hooks): -> 1786 return forward_call(*args, **kwargs) 1787 1788 result = None /usr/local/lib/python3.12/dist-packages/torch/nn/modules/transformer.py in forward(self, src, mask, src_key_padding_mask, is_causal) 522 523 for mod in self.layers: --> 524 output = mod( 525 output, 526 src_mask=mask, /usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs) 1773 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc] 1774 else: -> 1775 return self._call_impl(*args, **kwargs) 1776 1777 # torchrec tests the code consistency with the following code /usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs) 1784 or _global_backward_pre_hooks or _global_backward_hooks 1785 or _global_forward_hooks or _global_forward_pre_hooks): -> 1786 return forward_call(*args, **kwargs) 1787 1788 result = None /usr/local/lib/python3.12/dist-packages/torch/nn/modules/transformer.py in forward(self, src, src_mask, src_key_padding_mask, is_causal) 935 + self._sa_block(x, src_mask, src_key_padding_mask, is_causal=is_causal) 936 ) --> 937 x = self.norm2(x + self._ff_block(x)) 938 939 return x /usr/local/lib/python3.12/dist-packages/torch/nn/modules/transformer.py in _ff_block(self, x) 960 # feed forward block 961 def _ff_block(self, x: Tensor) -> Tensor: --> 962 x = self.linear2(self.dropout(self.activation(self.linear1(x)))) 963 return self.dropout2(x) 964 OutOfMemoryError: CUDA out of memory. Tried to allocate 250.00 MiB. GPU 0 has a total capacity of 14.74 GiB of which 180.12 MiB is free. Process 4969 has 14.56 GiB memory in use. Of the allocated memory 14.08 GiB is allocated by PyTorch, and 367.47 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
In [7]:
import torch
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.optim.swa_utils import AveragedModel, SWALR, update_bn
from torch.cuda.amp import GradScaler, autocast # <--- NEW: For Mixed Precision
import os
import csv
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
import gc
# --- CONFIGURATION ---
SAVE_DIR = '/content/drive/MyDrive/SatMAE_Scratch_Results_12Frames/'
CHECKPOINT_PATH = os.path.join(SAVE_DIR, "checkpoint_latest.pth")
BEST_MODEL_PATH = os.path.join(SAVE_DIR, "best_model.pth")
LOG_PATH = os.path.join(SAVE_DIR, "training_log.csv")
# --- MEMORY OPTIMIZATION SETTINGS ---
PHYSICAL_BATCH_SIZE = 4 # Small enough to fit in GPU
ACCUM_STEPS = 4 # 4 * 4 = 16 (Effective Batch Size)
MAX_EPOCHS = 5000
PATIENCE_TRIGGER = 100
SWA_DURATION = 50
# Ensure Directory
if not os.path.exists(SAVE_DIR): os.makedirs(SAVE_DIR)
# --- 1. DATASET DEFINITION ---
class SatMAEDataset(Dataset):
def __init__(self, x_path, y_path):
self.x_data = np.load(x_path, mmap_mode='r')
self.y_data = np.load(y_path, mmap_mode='r')
def __len__(self): return len(self.x_data)
def __getitem__(self, idx):
x_np = self.x_data[idx]
y_np = self.y_data[idx]
x_float = x_np.astype(np.float32)
y_float = y_np.astype(np.float32)
x_float[x_np == 255] = 0.0
y_float[y_np == 255] = 0.0
x_float /= 250.0
return torch.from_numpy(x_float), torch.from_numpy(y_float)
# --- 2. SETUP LOADERS ---
try:
print("Initializing Datasets...")
train_ds = SatMAEDataset(os.path.join(SAVE_DIR, 'train_x.npy'), os.path.join(SAVE_DIR, 'train_y.npy'))
val_ds = SatMAEDataset(os.path.join(SAVE_DIR, 'val_x.npy'), os.path.join(SAVE_DIR, 'val_y.npy'))
loaders = {
'train': DataLoader(train_ds, batch_size=PHYSICAL_BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True),
'val': DataLoader(val_ds, batch_size=PHYSICAL_BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)
}
print(" Loaders Ready.")
except Exception as e:
print(f" Error loading data: {e}")
raise e
# --- 3. MODEL SETUP ---
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")
# Clear Cache before starting
torch.cuda.empty_cache()
gc.collect()
try:
model = SatMAESegmentation(img_size=224, patch_size=16, in_chans=10, num_frames=12).to(device)
criterion = CompoundLoss()
except NameError:
raise NameError(" Error: 'SatMAESegmentation' not defined. Run Cell 2 first.")
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=50, T_mult=2)
scaler = GradScaler() # <--- NEW: Handles FP16 gradients
# SWA Setup
swa_model = AveragedModel(model)
swa_scheduler = SWALR(optimizer, swa_lr=5e-5)
# --- 4. RESUME LOGIC ---
start_epoch = 0
best_iou = 0.0
patience_counter = 0
swa_active = False
swa_epoch_counter = 0
history = {'epoch': [], 'train_loss': [], 'val_loss': [], 'val_iou': []}
if os.path.exists(CHECKPOINT_PATH):
print(" Found checkpoint. Resuming...")
ckpt = torch.load(CHECKPOINT_PATH, map_location=device)
model.load_state_dict(ckpt['model_state_dict'])
optimizer.load_state_dict(ckpt['optimizer_state_dict'])
scheduler.load_state_dict(ckpt['scheduler_state_dict'])
start_epoch = ckpt['epoch'] + 1
best_iou = ckpt['best_iou']
patience_counter = ckpt['patience_counter']
swa_active = ckpt['swa_active']
swa_epoch_counter = ckpt['swa_epoch_counter']
if swa_active:
swa_model.load_state_dict(ckpt['swa_state_dict'])
swa_scheduler.load_state_dict(ckpt['swa_scheduler_state_dict'])
if os.path.exists(LOG_PATH):
df = pd.read_csv(LOG_PATH)
history['epoch'] = df['epoch'].tolist()
history['train_loss'] = df['train_loss'].tolist()
history['val_loss'] = df['val_loss'].tolist()
history['val_iou'] = df['val_iou'].tolist()
print(f"Resumed at Epoch {start_epoch}. Best IoU: {best_iou:.4f}")
# --- 5. TRAINING LOOP (OPTIMIZED) ---
print(f"Starting Training. Max: {MAX_EPOCHS} Eps. Batch: {PHYSICAL_BATCH_SIZE} (Accum to 16)")
try:
for ep in range(start_epoch, MAX_EPOCHS):
model.train()
train_loss = 0
optimizer.zero_grad() # Initialize gradients
# Training Step
pbar = tqdm(loaders['train'], desc=f"Ep {ep+1}", leave=False)
for i, (x, y) in enumerate(pbar):
x, y = x.to(device), y.to(device)
# Mixed Precision Forward
with autocast(enabled=True):
preds = model(x)
loss = criterion(preds, y)
loss = loss / ACCUM_STEPS # Normalize loss
# Backward
scaler.scale(loss).backward()
# Update weights every ACCUM_STEPS
if (i + 1) % ACCUM_STEPS == 0:
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
train_loss += loss.item() * ACCUM_STEPS # Scale back up for logging
# Validation Step
eval_model = swa_model if swa_active else model
eval_model.eval()
val_loss = 0
val_iou_sum = 0
with torch.no_grad():
for x, y in loaders['val']:
x, y = x.to(device), y.to(device)
with autocast(enabled=True):
if swa_active: preds = eval_model(x)
else: preds = model(x)
val_loss += criterion(preds, y).item()
probs = torch.sigmoid(preds)
pred_mask = (probs > 0.5).float()
inter = (pred_mask * y).sum()
union = pred_mask.sum() + y.sum() - inter
iou = (inter + 1e-6) / (union + 1e-6)
val_iou_sum += iou.item()
# Stats
avg_t = train_loss / len(loaders['train'])
avg_v = val_loss / len(loaders['val'])
avg_iou = val_iou_sum / len(loaders['val'])
# Append to CSV
with open(LOG_PATH, 'a', newline='') as f:
writer = csv.writer(f)
if ep == 0 and not os.path.exists(LOG_PATH):
writer.writerow(['epoch', 'train_loss', 'val_loss', 'val_iou'])
writer.writerow([ep+1, avg_t, avg_v, avg_iou])
# Status & Logic
status = ""
if swa_active:
swa_model.update_parameters(model)
swa_scheduler.step()
swa_epoch_counter += 1
status = f"SWA Mode ({swa_epoch_counter}/{SWA_DURATION})"
if swa_epoch_counter >= SWA_DURATION:
print(f"Ep {ep+1} | T: {avg_t:.4f} | V: {avg_v:.4f} | IoU: {avg_iou:.4f} | {status}")
print(" SWA Complete. Saving Final Model.")
update_bn(loaders['train'], swa_model, device=device)
torch.save(swa_model.module.state_dict(), os.path.join(SAVE_DIR, "final_swa_model.pth"))
break
else:
scheduler.step()
if avg_iou > best_iou:
best_iou = avg_iou
patience_counter = 0
torch.save(model.state_dict(), BEST_MODEL_PATH)
status = f" Best IoU!"
else:
patience_counter += 1
status = f"No Improv ({patience_counter}/{PATIENCE_TRIGGER})"
if patience_counter >= PATIENCE_TRIGGER:
print(f" Patience Limit Reached. Triggering SWA for {SWA_DURATION} epochs...")
swa_active = True
swa_epoch_counter = 0
model.load_state_dict(torch.load(BEST_MODEL_PATH))
print(f"Ep {ep+1} | T: {avg_t:.4f} | V: {avg_v:.4f} | IoU: {avg_iou:.4f} | {status}")
# Save Checkpoint
torch.save({
'epoch': ep, 'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'scheduler_state_dict': scheduler.state_dict(),
'swa_state_dict': swa_model.state_dict() if swa_active else None,
'swa_scheduler_state_dict': swa_scheduler.state_dict() if swa_active else None,
'best_iou': best_iou, 'patience_counter': patience_counter,
'swa_active': swa_active, 'swa_epoch_counter': swa_epoch_counter
}, CHECKPOINT_PATH)
except KeyboardInterrupt:
print("Training Interrupted. Checkpoint Saved.")
# Plot Results
if os.path.exists(LOG_PATH):
df = pd.read_csv(LOG_PATH)
if len(df) > 0:
plt.figure(figsize=(10, 5))
plt.plot(df['train_loss'], label='Train Loss')
plt.plot(df['val_loss'], label='Val Loss')
plt.legend()
plt.title("Training Curves")
plt.savefig(os.path.join(SAVE_DIR, "loss_curve.png"))
plt.show()
Exception ignored in: <function _ConnectionBase.__del__ at 0x78fc142d4540> Traceback (most recent call last): File "/usr/lib/python3.12/multiprocessing/connection.py", line 133, in __del__ File "/usr/lib/python3.12/multiprocessing/connection.py", line 377, in _close OSError: [Errno 9] Bad file descriptor
Initializing Datasets... Loaders Ready. Device: cuda
/tmp/ipython-input-2721185616.py:77: FutureWarning: `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.
scaler = GradScaler() # <--- NEW: Handles FP16 gradients
--------------------------------------------------------------------------- OutOfMemoryError Traceback (most recent call last) /tmp/ipython-input-2721185616.py in <cell line: 0>() 78 79 # SWA Setup ---> 80 swa_model = AveragedModel(model) 81 swa_scheduler = SWALR(optimizer, swa_lr=5e-5) 82 /usr/local/lib/python3.12/dist-packages/torch/optim/swa_utils.py in __init__(self, model, device, avg_fn, multi_avg_fn, use_buffers) 232 "Only one of avg_fn and multi_avg_fn should be provided" 233 ) --> 234 self.module = deepcopy(model) 235 if device is not None: 236 self.module = self.module.to(device) /usr/lib/python3.12/copy.py in deepcopy(x, memo, _nil) 160 y = x 161 else: --> 162 y = _reconstruct(x, memo, *rv) 163 164 # If is its own copy, don't memoize. /usr/lib/python3.12/copy.py in _reconstruct(x, memo, func, args, state, listiter, dictiter, deepcopy) 257 if state is not None: 258 if deep: --> 259 state = deepcopy(state, memo) 260 if hasattr(y, '__setstate__'): 261 y.__setstate__(state) /usr/lib/python3.12/copy.py in deepcopy(x, memo, _nil) 134 copier = _deepcopy_dispatch.get(cls) 135 if copier is not None: --> 136 y = copier(x, memo) 137 else: 138 if issubclass(cls, type): /usr/lib/python3.12/copy.py in _deepcopy_dict(x, memo, deepcopy) 219 memo[id(x)] = y 220 for key, value in x.items(): --> 221 y[deepcopy(key, memo)] = deepcopy(value, memo) 222 return y 223 d[dict] = _deepcopy_dict /usr/lib/python3.12/copy.py in deepcopy(x, memo, _nil) 134 copier = _deepcopy_dispatch.get(cls) 135 if copier is not None: --> 136 y = copier(x, memo) 137 else: 138 if issubclass(cls, type): /usr/lib/python3.12/copy.py in _deepcopy_dict(x, memo, deepcopy) 219 memo[id(x)] = y 220 for key, value in x.items(): --> 221 y[deepcopy(key, memo)] = deepcopy(value, memo) 222 return y 223 d[dict] = _deepcopy_dict /usr/lib/python3.12/copy.py in deepcopy(x, memo, _nil) 160 y = x 161 else: --> 162 y = _reconstruct(x, memo, *rv) 163 164 # If is its own copy, don't memoize. /usr/lib/python3.12/copy.py in _reconstruct(x, memo, func, args, state, listiter, dictiter, deepcopy) 257 if state is not None: 258 if deep: --> 259 state = deepcopy(state, memo) 260 if hasattr(y, '__setstate__'): 261 y.__setstate__(state) /usr/lib/python3.12/copy.py in deepcopy(x, memo, _nil) 134 copier = _deepcopy_dispatch.get(cls) 135 if copier is not None: --> 136 y = copier(x, memo) 137 else: 138 if issubclass(cls, type): /usr/lib/python3.12/copy.py in _deepcopy_dict(x, memo, deepcopy) 219 memo[id(x)] = y 220 for key, value in x.items(): --> 221 y[deepcopy(key, memo)] = deepcopy(value, memo) 222 return y 223 d[dict] = _deepcopy_dict /usr/lib/python3.12/copy.py in deepcopy(x, memo, _nil) 134 copier = _deepcopy_dispatch.get(cls) 135 if copier is not None: --> 136 y = copier(x, memo) 137 else: 138 if issubclass(cls, type): /usr/lib/python3.12/copy.py in _deepcopy_dict(x, memo, deepcopy) 219 memo[id(x)] = y 220 for key, value in x.items(): --> 221 y[deepcopy(key, memo)] = deepcopy(value, memo) 222 return y 223 d[dict] = _deepcopy_dict /usr/lib/python3.12/copy.py in deepcopy(x, memo, _nil) 160 y = x 161 else: --> 162 y = _reconstruct(x, memo, *rv) 163 164 # If is its own copy, don't memoize. /usr/lib/python3.12/copy.py in _reconstruct(x, memo, func, args, state, listiter, dictiter, deepcopy) 257 if state is not None: 258 if deep: --> 259 state = deepcopy(state, memo) 260 if hasattr(y, '__setstate__'): 261 y.__setstate__(state) /usr/lib/python3.12/copy.py in deepcopy(x, memo, _nil) 134 copier = _deepcopy_dispatch.get(cls) 135 if copier is not None: --> 136 y = copier(x, memo) 137 else: 138 if issubclass(cls, type): /usr/lib/python3.12/copy.py in _deepcopy_dict(x, memo, deepcopy) 219 memo[id(x)] = y 220 for key, value in x.items(): --> 221 y[deepcopy(key, memo)] = deepcopy(value, memo) 222 return y 223 d[dict] = _deepcopy_dict /usr/lib/python3.12/copy.py in deepcopy(x, memo, _nil) 134 copier = _deepcopy_dispatch.get(cls) 135 if copier is not None: --> 136 y = copier(x, memo) 137 else: 138 if issubclass(cls, type): /usr/lib/python3.12/copy.py in _deepcopy_dict(x, memo, deepcopy) 219 memo[id(x)] = y 220 for key, value in x.items(): --> 221 y[deepcopy(key, memo)] = deepcopy(value, memo) 222 return y 223 d[dict] = _deepcopy_dict /usr/lib/python3.12/copy.py in deepcopy(x, memo, _nil) 160 y = x 161 else: --> 162 y = _reconstruct(x, memo, *rv) 163 164 # If is its own copy, don't memoize. /usr/lib/python3.12/copy.py in _reconstruct(x, memo, func, args, state, listiter, dictiter, deepcopy) 257 if state is not None: 258 if deep: --> 259 state = deepcopy(state, memo) 260 if hasattr(y, '__setstate__'): 261 y.__setstate__(state) /usr/lib/python3.12/copy.py in deepcopy(x, memo, _nil) 134 copier = _deepcopy_dispatch.get(cls) 135 if copier is not None: --> 136 y = copier(x, memo) 137 else: 138 if issubclass(cls, type): /usr/lib/python3.12/copy.py in _deepcopy_dict(x, memo, deepcopy) 219 memo[id(x)] = y 220 for key, value in x.items(): --> 221 y[deepcopy(key, memo)] = deepcopy(value, memo) 222 return y 223 d[dict] = _deepcopy_dict /usr/lib/python3.12/copy.py in deepcopy(x, memo, _nil) 134 copier = _deepcopy_dispatch.get(cls) 135 if copier is not None: --> 136 y = copier(x, memo) 137 else: 138 if issubclass(cls, type): /usr/lib/python3.12/copy.py in _deepcopy_dict(x, memo, deepcopy) 219 memo[id(x)] = y 220 for key, value in x.items(): --> 221 y[deepcopy(key, memo)] = deepcopy(value, memo) 222 return y 223 d[dict] = _deepcopy_dict /usr/lib/python3.12/copy.py in deepcopy(x, memo, _nil) 160 y = x 161 else: --> 162 y = _reconstruct(x, memo, *rv) 163 164 # If is its own copy, don't memoize. /usr/lib/python3.12/copy.py in _reconstruct(x, memo, func, args, state, listiter, dictiter, deepcopy) 257 if state is not None: 258 if deep: --> 259 state = deepcopy(state, memo) 260 if hasattr(y, '__setstate__'): 261 y.__setstate__(state) /usr/lib/python3.12/copy.py in deepcopy(x, memo, _nil) 134 copier = _deepcopy_dispatch.get(cls) 135 if copier is not None: --> 136 y = copier(x, memo) 137 else: 138 if issubclass(cls, type): /usr/lib/python3.12/copy.py in _deepcopy_dict(x, memo, deepcopy) 219 memo[id(x)] = y 220 for key, value in x.items(): --> 221 y[deepcopy(key, memo)] = deepcopy(value, memo) 222 return y 223 d[dict] = _deepcopy_dict /usr/lib/python3.12/copy.py in deepcopy(x, memo, _nil) 134 copier = _deepcopy_dispatch.get(cls) 135 if copier is not None: --> 136 y = copier(x, memo) 137 else: 138 if issubclass(cls, type): /usr/lib/python3.12/copy.py in _deepcopy_dict(x, memo, deepcopy) 219 memo[id(x)] = y 220 for key, value in x.items(): --> 221 y[deepcopy(key, memo)] = deepcopy(value, memo) 222 return y 223 d[dict] = _deepcopy_dict /usr/lib/python3.12/copy.py in deepcopy(x, memo, _nil) 160 y = x 161 else: --> 162 y = _reconstruct(x, memo, *rv) 163 164 # If is its own copy, don't memoize. /usr/lib/python3.12/copy.py in _reconstruct(x, memo, func, args, state, listiter, dictiter, deepcopy) 257 if state is not None: 258 if deep: --> 259 state = deepcopy(state, memo) 260 if hasattr(y, '__setstate__'): 261 y.__setstate__(state) /usr/lib/python3.12/copy.py in deepcopy(x, memo, _nil) 134 copier = _deepcopy_dispatch.get(cls) 135 if copier is not None: --> 136 y = copier(x, memo) 137 else: 138 if issubclass(cls, type): /usr/lib/python3.12/copy.py in _deepcopy_dict(x, memo, deepcopy) 219 memo[id(x)] = y 220 for key, value in x.items(): --> 221 y[deepcopy(key, memo)] = deepcopy(value, memo) 222 return y 223 d[dict] = _deepcopy_dict /usr/lib/python3.12/copy.py in deepcopy(x, memo, _nil) 134 copier = _deepcopy_dispatch.get(cls) 135 if copier is not None: --> 136 y = copier(x, memo) 137 else: 138 if issubclass(cls, type): /usr/lib/python3.12/copy.py in _deepcopy_dict(x, memo, deepcopy) 219 memo[id(x)] = y 220 for key, value in x.items(): --> 221 y[deepcopy(key, memo)] = deepcopy(value, memo) 222 return y 223 d[dict] = _deepcopy_dict /usr/lib/python3.12/copy.py in deepcopy(x, memo, _nil) 141 copier = getattr(x, "__deepcopy__", None) 142 if copier is not None: --> 143 y = copier(memo) 144 else: 145 reductor = dispatch_table.get(cls) /usr/local/lib/python3.12/dist-packages/torch/nn/parameter.py in __deepcopy__(self, memo) 77 else: 78 result = type(self)( ---> 79 self.data.clone(memory_format=torch.preserve_format), self.requires_grad 80 ) 81 memo[id(self)] = result OutOfMemoryError: CUDA out of memory. Tried to allocate 20.00 MiB. GPU 0 has a total capacity of 14.74 GiB of which 8.12 MiB is free. Process 4969 has 14.73 GiB memory in use. Of the allocated memory 14.28 GiB is allocated by PyTorch, and 331.81 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
In [8]:
import torch
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.optim.swa_utils import AveragedModel, SWALR, update_bn
from torch.cuda.amp import GradScaler, autocast
import os
import csv
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
import gc
# --- CONFIGURATION ---
SAVE_DIR = '/content/drive/MyDrive/SatMAE_Scratch_Results_12Frames/'
CHECKPOINT_PATH = os.path.join(SAVE_DIR, "checkpoint_latest.pth")
BEST_MODEL_PATH = os.path.join(SAVE_DIR, "best_model.pth")
LOG_PATH = os.path.join(SAVE_DIR, "training_log.csv")
# --- ULTRA-LOW MEMORY SETTINGS ---
PHYSICAL_BATCH_SIZE = 2 # <--- REDUCED to 2 to fix OOM
ACCUM_STEPS = 8 # <--- INCREASED to 8 (2 * 8 = 16 Effective Batch)
MAX_EPOCHS = 5000
PATIENCE_TRIGGER = 100
SWA_DURATION = 50
# Ensure Directory
if not os.path.exists(SAVE_DIR): os.makedirs(SAVE_DIR)
# --- 1. DATASET DEFINITION ---
class SatMAEDataset(Dataset):
def __init__(self, x_path, y_path):
self.x_data = np.load(x_path, mmap_mode='r')
self.y_data = np.load(y_path, mmap_mode='r')
def __len__(self): return len(self.x_data)
def __getitem__(self, idx):
x_np = self.x_data[idx]
y_np = self.y_data[idx]
x_float = x_np.astype(np.float32)
y_float = y_np.astype(np.float32)
x_float[x_np == 255] = 0.0
y_float[y_np == 255] = 0.0
x_float /= 250.0
return torch.from_numpy(x_float), torch.from_numpy(y_float)
# --- 2. SETUP LOADERS ---
try:
print("Initializing Datasets...")
train_ds = SatMAEDataset(os.path.join(SAVE_DIR, 'train_x.npy'), os.path.join(SAVE_DIR, 'train_y.npy'))
val_ds = SatMAEDataset(os.path.join(SAVE_DIR, 'val_x.npy'), os.path.join(SAVE_DIR, 'val_y.npy'))
loaders = {
'train': DataLoader(train_ds, batch_size=PHYSICAL_BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True),
'val': DataLoader(val_ds, batch_size=PHYSICAL_BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)
}
print(" Loaders Ready.")
except Exception as e:
print(f" Error loading data: {e}")
raise e
# --- 3. MODEL SETUP ---
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")
# Aggressive Memory Cleanup
torch.cuda.empty_cache()
gc.collect()
try:
model = SatMAESegmentation(img_size=224, patch_size=16, in_chans=10, num_frames=12).to(device)
criterion = CompoundLoss()
except NameError:
raise NameError(" Error: 'SatMAESegmentation' not defined. Run Cell 2 first.")
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=50, T_mult=2)
scaler = GradScaler()
# SWA Setup (Initialize on CPU first to save VRAM, move to GPU only when needed if possible,
# but for simplicity we keep on GPU. With batch=2 it should fit.)
swa_model = AveragedModel(model)
swa_scheduler = SWALR(optimizer, swa_lr=5e-5)
# --- 4. RESUME LOGIC ---
start_epoch = 0
best_iou = 0.0
patience_counter = 0
swa_active = False
swa_epoch_counter = 0
history = {'epoch': [], 'train_loss': [], 'val_loss': [], 'val_iou': []}
if os.path.exists(CHECKPOINT_PATH):
print(" Found checkpoint. Resuming...")
try:
ckpt = torch.load(CHECKPOINT_PATH, map_location=device)
model.load_state_dict(ckpt['model_state_dict'])
optimizer.load_state_dict(ckpt['optimizer_state_dict'])
scheduler.load_state_dict(ckpt['scheduler_state_dict'])
start_epoch = ckpt['epoch'] + 1
best_iou = ckpt['best_iou']
patience_counter = ckpt['patience_counter']
swa_active = ckpt['swa_active']
swa_epoch_counter = ckpt['swa_epoch_counter']
if swa_active:
swa_model.load_state_dict(ckpt['swa_state_dict'])
swa_scheduler.load_state_dict(ckpt['swa_scheduler_state_dict'])
if os.path.exists(LOG_PATH):
df = pd.read_csv(LOG_PATH)
history['epoch'] = df['epoch'].tolist()
history['train_loss'] = df['train_loss'].tolist()
history['val_loss'] = df['val_loss'].tolist()
history['val_iou'] = df['val_iou'].tolist()
print(f"Resumed at Epoch {start_epoch}. Best IoU: {best_iou:.4f}")
except Exception as e:
print(f" Warning: Checkpoint corrupt or mismatch. Starting fresh. ({e})")
# --- 5. TRAINING LOOP ---
print(f" Starting Training. Max: {MAX_EPOCHS} Eps. Batch: {PHYSICAL_BATCH_SIZE} (Accum to 16)")
try:
for ep in range(start_epoch, MAX_EPOCHS):
model.train()
train_loss = 0
optimizer.zero_grad(set_to_none=True) # Saves Memory
# Training Step
pbar = tqdm(loaders['train'], desc=f"Ep {ep+1}", leave=False)
for i, (x, y) in enumerate(pbar):
x, y = x.to(device), y.to(device)
# Mixed Precision Forward
with autocast(enabled=True):
preds = model(x)
loss = criterion(preds, y)
loss = loss / ACCUM_STEPS
# Backward
scaler.scale(loss).backward()
# Update weights every ACCUM_STEPS
if (i + 1) % ACCUM_STEPS == 0:
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad(set_to_none=True)
train_loss += loss.item() * ACCUM_STEPS
# Validation Step
eval_model = swa_model if swa_active else model
eval_model.eval()
val_loss = 0
val_iou_sum = 0
# Explicit Memory Cleanup before Validation
torch.cuda.empty_cache()
with torch.no_grad():
for x, y in loaders['val']:
x, y = x.to(device), y.to(device)
with autocast(enabled=True):
if swa_active: preds = eval_model(x)
else: preds = model(x)
val_loss += criterion(preds, y).item()
probs = torch.sigmoid(preds)
pred_mask = (probs > 0.5).float()
inter = (pred_mask * y).sum()
union = pred_mask.sum() + y.sum() - inter
iou = (inter + 1e-6) / (union + 1e-6)
val_iou_sum += iou.item()
# Stats
avg_t = train_loss / len(loaders['train'])
avg_v = val_loss / len(loaders['val'])
avg_iou = val_iou_sum / len(loaders['val'])
# Append to CSV
with open(LOG_PATH, 'a', newline='') as f:
writer = csv.writer(f)
if ep == 0 and not os.path.exists(LOG_PATH):
writer.writerow(['epoch', 'train_loss', 'val_loss', 'val_iou'])
writer.writerow([ep+1, avg_t, avg_v, avg_iou])
# Status & Logic
status = ""
if swa_active:
swa_model.update_parameters(model)
swa_scheduler.step()
swa_epoch_counter += 1
status = f"SWA Mode ({swa_epoch_counter}/{SWA_DURATION})"
if swa_epoch_counter >= SWA_DURATION:
print(f"Ep {ep+1} | T: {avg_t:.4f} | V: {avg_v:.4f} | IoU: {avg_iou:.4f} | {status}")
print("SWA Complete. Saving Final Model.")
update_bn(loaders['train'], swa_model, device=device)
torch.save(swa_model.module.state_dict(), os.path.join(SAVE_DIR, "final_swa_model.pth"))
break
else:
scheduler.step()
if avg_iou > best_iou:
best_iou = avg_iou
patience_counter = 0
torch.save(model.state_dict(), BEST_MODEL_PATH)
status = f" Best IoU!"
else:
patience_counter += 1
status = f"No Improv ({patience_counter}/{PATIENCE_TRIGGER})"
if patience_counter >= PATIENCE_TRIGGER:
print(f" Patience Limit Reached. Triggering SWA for {SWA_DURATION} epochs...")
swa_active = True
swa_epoch_counter = 0
model.load_state_dict(torch.load(BEST_MODEL_PATH))
print(f"Ep {ep+1} | T: {avg_t:.4f} | V: {avg_v:.4f} | IoU: {avg_iou:.4f} | {status}")
# Save Checkpoint
torch.save({
'epoch': ep, 'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'scheduler_state_dict': scheduler.state_dict(),
'swa_state_dict': swa_model.state_dict() if swa_active else None,
'swa_scheduler_state_dict': swa_scheduler.state_dict() if swa_active else None,
'best_iou': best_iou, 'patience_counter': patience_counter,
'swa_active': swa_active, 'swa_epoch_counter': swa_epoch_counter
}, CHECKPOINT_PATH)
except KeyboardInterrupt:
print("Training Interrupted. Checkpoint Saved.")
# Plot Results
if os.path.exists(LOG_PATH):
df = pd.read_csv(LOG_PATH)
if len(df) > 0:
plt.figure(figsize=(10, 5))
plt.plot(df['train_loss'], label='Train Loss')
plt.plot(df['val_loss'], label='Val Loss')
plt.legend()
plt.title("Training Curves")
plt.savefig(os.path.join(SAVE_DIR, "loss_curve.png"))
plt.show()
Initializing Datasets... Loaders Ready. Device: cuda Starting Training. Max: 5000 Eps. Batch: 2 (Accum to 16)
/tmp/ipython-input-4233285996.py:77: FutureWarning: `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.
scaler = GradScaler()
Ep 1: 0%| | 0/5 [00:00<?, ?it/s]
/tmp/ipython-input-4233285996.py:133: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
with autocast(enabled=True):
--------------------------------------------------------------------------- RuntimeError Traceback (most recent call last) /tmp/ipython-input-4233285996.py in <cell line: 0>() 137 138 # Backward --> 139 scaler.scale(loss).backward() 140 141 # Update weights every ACCUM_STEPS /usr/local/lib/python3.12/dist-packages/torch/_tensor.py in backward(self, gradient, retain_graph, create_graph, inputs) 623 inputs=inputs, 624 ) --> 625 torch.autograd.backward( 626 self, gradient, retain_graph, create_graph, inputs=inputs 627 ) /usr/local/lib/python3.12/dist-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs) 352 # some Python versions print out the first line of a multi-line function 353 # calls in the traceback and some print out the last line --> 354 _engine_run_backward( 355 tensors, 356 grad_tensors_, /usr/local/lib/python3.12/dist-packages/torch/autograd/graph.py in _engine_run_backward(t_outputs, *args, **kwargs) 839 unregister_hooks = _register_logging_hooks_on_whole_graph(t_outputs) 840 try: --> 841 return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass 842 t_outputs, *args, **kwargs 843 ) # Calls into the C++ engine to run the backward pass RuntimeError: CUDA error: CUBLAS_STATUS_ALLOC_FAILED when calling `cublasCreate(handle)`
In [10]:
import torch
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.optim.swa_utils import AveragedModel, SWALR, update_bn
from torch.cuda.amp import GradScaler, autocast
import os
import csv
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
import gc
# --- CONFIGURATION ---
SAVE_DIR = '/content/drive/MyDrive/SatMAE_Scratch_Results_12Frames/'
CHECKPOINT_PATH = os.path.join(SAVE_DIR, "checkpoint_latest.pth")
BEST_MODEL_PATH = os.path.join(SAVE_DIR, "best_model.pth")
LOG_PATH = os.path.join(SAVE_DIR, "training_log.csv")
# --- ULTRA-LOW MEMORY SETTINGS ---
PHYSICAL_BATCH_SIZE = 1 # Process 1 image at a time
ACCUM_STEPS = 16 # Update weights every 16 images (Effective Batch = 16)
MAX_EPOCHS = 5000
PATIENCE_TRIGGER = 100
SWA_DURATION = 50
# Ensure Directory
if not os.path.exists(SAVE_DIR): os.makedirs(SAVE_DIR)
# --- 1. HELPERS & DATASET ---
def compute_metrics(pred_probs, targets, threshold=0.5):
"""Calculates IoU and Dice Score"""
pred_mask = (pred_probs > threshold).float()
intersection = (pred_mask * targets).sum()
union = pred_mask.sum() + targets.sum() - intersection
iou = (intersection + 1e-6) / (union + 1e-6)
dice = (2 * intersection + 1e-6) / (pred_mask.sum() + targets.sum() + 1e-6)
return iou.item(), dice.item()
class SatMAEDataset(Dataset):
def __init__(self, x_path, y_path):
self.x_data = np.load(x_path, mmap_mode='r')
self.y_data = np.load(y_path, mmap_mode='r')
def __len__(self): return len(self.x_data)
def __getitem__(self, idx):
x_np = self.x_data[idx]
y_np = self.y_data[idx]
x_float = x_np.astype(np.float32)
y_float = y_np.astype(np.float32)
x_float[x_np == 255] = 0.0 # Handle NoData
y_float[y_np == 255] = 0.0
x_float /= 250.0 # Scale 0-250 -> 0-1
return torch.from_numpy(x_float), torch.from_numpy(y_float)
# --- 2. SETUP LOADERS ---
try:
print("Initializing Datasets...")
train_ds = SatMAEDataset(os.path.join(SAVE_DIR, 'train_x.npy'), os.path.join(SAVE_DIR, 'train_y.npy'))
val_ds = SatMAEDataset(os.path.join(SAVE_DIR, 'val_x.npy'), os.path.join(SAVE_DIR, 'val_y.npy'))
loaders = {
'train': DataLoader(train_ds, batch_size=PHYSICAL_BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True),
'val': DataLoader(val_ds, batch_size=PHYSICAL_BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)
}
except Exception as e:
raise RuntimeError(f" Error loading data. Did you run Cell 1? ({e})")
# --- 3. MODEL SETUP ---
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")
# Aggressive Cleanup
torch.cuda.empty_cache()
gc.collect()
try:
# Ensure Cell 2 was run
model = SatMAESegmentation(img_size=224, patch_size=16, in_chans=10, num_frames=12).to(device)
criterion = CompoundLoss()
except NameError:
raise NameError(" Error: Model classes not found. Please RUN CELL 2 (Architecture) first.")
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=50, T_mult=2)
scaler = GradScaler() # Mixed Precision Scaler
# SWA Setup
swa_model = AveragedModel(model)
swa_scheduler = SWALR(optimizer, swa_lr=5e-5)
# --- 4. RESUME LOGIC ---
start_epoch = 0; best_iou = 0.0; patience_counter = 0; swa_active = False; swa_epoch_counter = 0
history = {'epoch': [], 'train_loss': [], 'val_loss': [], 'val_iou': []}
if os.path.exists(CHECKPOINT_PATH):
print(" Resuming from Checkpoint...")
try:
ckpt = torch.load(CHECKPOINT_PATH, map_location=device)
model.load_state_dict(ckpt['model_state_dict'])
optimizer.load_state_dict(ckpt['optimizer_state_dict'])
scheduler.load_state_dict(ckpt['scheduler_state_dict'])
start_epoch = ckpt['epoch'] + 1
best_iou = ckpt['best_iou']
patience_counter = ckpt['patience_counter']
swa_active = ckpt['swa_active']
swa_epoch_counter = ckpt['swa_epoch_counter']
if swa_active:
swa_model.load_state_dict(ckpt['swa_state_dict'])
swa_scheduler.load_state_dict(ckpt['swa_scheduler_state_dict'])
if os.path.exists(LOG_PATH):
df = pd.read_csv(LOG_PATH)
history['epoch'] = df['epoch'].tolist()
history['train_loss'] = df['train_loss'].tolist()
history['val_loss'] = df['val_loss'].tolist()
history['val_iou'] = df['val_iou'].tolist()
print(f"Resumed at Epoch {start_epoch}. Best IoU: {best_iou:.4f}")
except Exception as e:
print(f" Checkpoint load failed ({e}). Starting fresh.")
# --- 5. TRAINING LOOP (The Engine) ---
print(f" Training Start. Batch: {PHYSICAL_BATCH_SIZE} | Accum: {ACCUM_STEPS}")
print("-" * 50)
try:
for ep in range(start_epoch, MAX_EPOCHS):
model.train()
train_loss = 0
optimizer.zero_grad(set_to_none=True)
pbar = tqdm(loaders['train'], desc=f"Ep {ep+1}", leave=False)
for i, (x, y) in enumerate(pbar):
x, y = x.to(device), y.to(device)
# Forward (Mixed Precision)
with autocast(enabled=True):
preds = model(x)
loss = criterion(preds, y)
loss = loss / ACCUM_STEPS # Normalize loss
# Backward
scaler.scale(loss).backward()
# Step (Gradient Accumulation)
if (i + 1) % ACCUM_STEPS == 0:
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad(set_to_none=True)
train_loss += loss.item() * ACCUM_STEPS
# Cleanup before Validation
del x, y, preds, loss
torch.cuda.empty_cache()
# Validation
eval_model = swa_model if swa_active else model
eval_model.eval()
val_loss = 0
val_iou_sum = 0
with torch.no_grad():
for x, y in loaders['val']:
x, y = x.to(device), y.to(device)
with autocast(enabled=True):
if swa_active: preds = eval_model(x)
else: preds = model(x)
val_loss += criterion(preds, y).item()
probs = torch.sigmoid(preds)
iou, _ = compute_metrics(probs, y)
val_iou_sum += iou
# Stats
avg_t = train_loss / len(loaders['train'])
avg_v = val_loss / len(loaders['val'])
avg_iou = val_iou_sum / len(loaders['val'])
# Logging
with open(LOG_PATH, 'a', newline='') as f:
writer = csv.writer(f)
if ep == 0 and not os.path.exists(LOG_PATH):
writer.writerow(['epoch', 'train_loss', 'val_loss', 'val_iou'])
writer.writerow([ep+1, avg_t, avg_v, avg_iou])
# SWA & Patience Logic [Visual Flowchart Reference below]
#
status = ""
if swa_active:
swa_model.update_parameters(model)
swa_scheduler.step()
swa_epoch_counter += 1
status = f"SWA Mode ({swa_epoch_counter}/{SWA_DURATION})"
if swa_epoch_counter >= SWA_DURATION:
print(f"Ep {ep+1} | T: {avg_t:.4f} | V: {avg_v:.4f} | IoU: {avg_iou:.4f} | {status}")
print(" SWA Complete. Saving Final Model.")
update_bn(loaders['train'], swa_model, device=device)
torch.save(swa_model.module.state_dict(), os.path.join(SAVE_DIR, "final_swa_model.pth"))
break
else:
scheduler.step()
if avg_iou > best_iou:
best_iou = avg_iou
patience_counter = 0
torch.save(model.state_dict(), BEST_MODEL_PATH)
status = f" Best IoU!"
else:
patience_counter += 1
status = f"Wait ({patience_counter})"
if patience_counter >= PATIENCE_TRIGGER:
print(f" Patience Limit Reached. Triggering SWA for {SWA_DURATION} epochs...")
swa_active = True
swa_epoch_counter = 0
# Reset to best model before starting SWA
model.load_state_dict(torch.load(BEST_MODEL_PATH))
print(f"Ep {ep+1} | T: {avg_t:.4f} | V: {avg_v:.4f} | IoU: {avg_iou:.4f} | {status}")
# Rolling Checkpoint
torch.save({
'epoch': ep, 'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'scheduler_state_dict': scheduler.state_dict(),
'swa_state_dict': swa_model.state_dict() if swa_active else None,
'swa_scheduler_state_dict': swa_scheduler.state_dict() if swa_active else None,
'best_iou': best_iou, 'patience_counter': patience_counter,
'swa_active': swa_active, 'swa_epoch_counter': swa_epoch_counter
}, CHECKPOINT_PATH)
except KeyboardInterrupt:
print(" Training Interrupted. Checkpoint Saved.")
# Plot Results
if os.path.exists(LOG_PATH):
df = pd.read_csv(LOG_PATH)
if len(df) > 0:
plt.figure(figsize=(10, 5))
plt.plot(df['train_loss'], label='Train Loss')
plt.plot(df['val_loss'], label='Val Loss')
plt.legend()
plt.title("Training Curves")
plt.savefig(os.path.join(SAVE_DIR, "loss_curve.png"))
plt.show()
Initializing Datasets... Device: cuda Training Start. Batch: 1 | Accum: 16 --------------------------------------------------
/tmp/ipython-input-3170614832.py:84: FutureWarning: `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.
scaler = GradScaler() # Mixed Precision Scaler
Ep 1: 0%| | 0/9 [00:00<?, ?it/s]
/tmp/ipython-input-3170614832.py:136: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
with autocast(enabled=True):
/tmp/ipython-input-3170614832.py:166: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
with autocast(enabled=True):
Ep 1 | T: 1.4361 | V: 0.9212 | IoU: 0.0000 | 🏆 Best IoU!
Ep 2: 0%| | 0/9 [00:00<?, ?it/s]
Ep 2 | T: 1.4359 | V: 0.9345 | IoU: 0.0245 | 🏆 Best IoU!
Ep 3: 0%| | 0/9 [00:00<?, ?it/s]
Ep 3 | T: 1.4360 | V: 0.9901 | IoU: 0.3243 | 🏆 Best IoU!
Ep 4: 0%| | 0/9 [00:00<?, ?it/s]
Ep 4 | T: 1.4360 | V: 1.1157 | IoU: 0.3136 | Wait (1)
Ep 5: 0%| | 0/9 [00:00<?, ?it/s]
Ep 5 | T: 1.4363 | V: 1.2833 | IoU: 0.2935 | Wait (2)
Ep 6: 0%| | 0/9 [00:00<?, ?it/s]
Ep 6 | T: 1.4359 | V: 1.4157 | IoU: 0.2826 | Wait (3)
Ep 7: 0%| | 0/9 [00:00<?, ?it/s]
Ep 7 | T: 1.4355 | V: 1.4877 | IoU: 0.2782 | Wait (4)
Ep 8: 0%| | 0/9 [00:00<?, ?it/s]
Ep 8 | T: 1.4359 | V: 1.5201 | IoU: 0.2767 | Wait (5)
Ep 9: 0%| | 0/9 [00:00<?, ?it/s]
Ep 9 | T: 1.4362 | V: 1.5329 | IoU: 0.2761 | Wait (6)
Ep 10: 0%| | 0/9 [00:00<?, ?it/s]
Ep 10 | T: 1.4365 | V: 1.5380 | IoU: 0.2762 | Wait (7)
Ep 11: 0%| | 0/9 [00:00<?, ?it/s]
Ep 11 | T: 1.4367 | V: 1.5406 | IoU: 0.2762 | Wait (8)
Ep 12: 0%| | 0/9 [00:00<?, ?it/s]
Ep 12 | T: 1.4358 | V: 1.5412 | IoU: 0.2761 | Wait (9)
Ep 13: 0%| | 0/9 [00:00<?, ?it/s]
Ep 13 | T: 1.4361 | V: 1.5414 | IoU: 0.2760 | Wait (10)
Ep 14: 0%| | 0/9 [00:00<?, ?it/s]
Ep 14 | T: 1.4366 | V: 1.5416 | IoU: 0.2760 | Wait (11)
Ep 15: 0%| | 0/9 [00:00<?, ?it/s]
Ep 15 | T: 1.4357 | V: 1.5416 | IoU: 0.2759 | Wait (12)
Ep 16: 0%| | 0/9 [00:00<?, ?it/s]
Ep 16 | T: 1.4358 | V: 1.5414 | IoU: 0.2760 | Wait (13)
Ep 17: 0%| | 0/9 [00:00<?, ?it/s]
Ep 17 | T: 1.4363 | V: 1.5413 | IoU: 0.2760 | Wait (14)
Ep 18: 0%| | 0/9 [00:00<?, ?it/s]
Ep 18 | T: 1.4356 | V: 1.5418 | IoU: 0.2761 | Wait (15)
Ep 19: 0%| | 0/9 [00:00<?, ?it/s]
Ep 19 | T: 1.4362 | V: 1.5419 | IoU: 0.2759 | Wait (16)
Ep 20: 0%| | 0/9 [00:00<?, ?it/s]
Ep 20 | T: 1.4363 | V: 1.5418 | IoU: 0.2759 | Wait (17)
Ep 21: 0%| | 0/9 [00:00<?, ?it/s]
Ep 21 | T: 1.4360 | V: 1.5415 | IoU: 0.2760 | Wait (18)
Ep 22: 0%| | 0/9 [00:00<?, ?it/s]
Ep 22 | T: 1.4365 | V: 1.5417 | IoU: 0.2761 | Wait (19)
Ep 23: 0%| | 0/9 [00:00<?, ?it/s]
Ep 23 | T: 1.4361 | V: 1.5414 | IoU: 0.2758 | Wait (20)
Ep 24: 0%| | 0/9 [00:00<?, ?it/s]
Ep 24 | T: 1.4361 | V: 1.5415 | IoU: 0.2760 | Wait (21)
Ep 25: 0%| | 0/9 [00:00<?, ?it/s]
Ep 25 | T: 1.4361 | V: 1.5415 | IoU: 0.2759 | Wait (22)
Ep 26: 0%| | 0/9 [00:00<?, ?it/s]
Ep 26 | T: 1.4359 | V: 1.5415 | IoU: 0.2761 | Wait (23)
Ep 27: 0%| | 0/9 [00:00<?, ?it/s]
Ep 27 | T: 1.4351 | V: 1.5418 | IoU: 0.2761 | Wait (24)
Ep 28: 0%| | 0/9 [00:00<?, ?it/s]
Ep 28 | T: 1.4359 | V: 1.5415 | IoU: 0.2759 | Wait (25)
Ep 29: 0%| | 0/9 [00:00<?, ?it/s]
Ep 29 | T: 1.4355 | V: 1.5417 | IoU: 0.2761 | Wait (26)
Ep 30: 0%| | 0/9 [00:00<?, ?it/s]
Ep 30 | T: 1.4361 | V: 1.5414 | IoU: 0.2759 | Wait (27)
Ep 31: 0%| | 0/9 [00:00<?, ?it/s]
Ep 31 | T: 1.4359 | V: 1.5413 | IoU: 0.2762 | Wait (28) Training Interrupted. Checkpoint Saved.
--------------------------------------------------------------------------- KeyError Traceback (most recent call last) /usr/local/lib/python3.12/dist-packages/pandas/core/indexes/base.py in get_loc(self, key) 3804 try: -> 3805 return self._engine.get_loc(casted_key) 3806 except KeyError as err: index.pyx in pandas._libs.index.IndexEngine.get_loc() index.pyx in pandas._libs.index.IndexEngine.get_loc() pandas/_libs/hashtable_class_helper.pxi in pandas._libs.hashtable.PyObjectHashTable.get_item() pandas/_libs/hashtable_class_helper.pxi in pandas._libs.hashtable.PyObjectHashTable.get_item() KeyError: 'train_loss' The above exception was the direct cause of the following exception: KeyError Traceback (most recent call last) /tmp/ipython-input-3170614832.py in <cell line: 0>() 239 if len(df) > 0: 240 plt.figure(figsize=(10, 5)) --> 241 plt.plot(df['train_loss'], label='Train Loss') 242 plt.plot(df['val_loss'], label='Val Loss') 243 plt.legend() /usr/local/lib/python3.12/dist-packages/pandas/core/frame.py in __getitem__(self, key) 4100 if self.columns.nlevels > 1: 4101 return self._getitem_multilevel(key) -> 4102 indexer = self.columns.get_loc(key) 4103 if is_integer(indexer): 4104 indexer = [indexer] /usr/local/lib/python3.12/dist-packages/pandas/core/indexes/base.py in get_loc(self, key) 3810 ): 3811 raise InvalidIndexError(key) -> 3812 raise KeyError(key) from err 3813 except TypeError: 3814 # If we have a listlike key, _check_indexing_error will raise KeyError: 'train_loss'
<Figure size 1000x500 with 0 Axes>
In [11]:
import torch
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.optim.swa_utils import AveragedModel, SWALR, update_bn
from torch.cuda.amp import GradScaler, autocast
import os
import csv
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
import gc
# --- CONFIGURATION ---
SAVE_DIR = '/content/drive/MyDrive/SatMAE_Scratch_Results_12Frames/'
CHECKPOINT_PATH = os.path.join(SAVE_DIR, "checkpoint_latest.pth")
BEST_MODEL_PATH = os.path.join(SAVE_DIR, "best_model.pth")
LOG_PATH = os.path.join(SAVE_DIR, "training_log.csv")
# --- ECO-MODE SETTINGS (CRASH PREVENTION) ---
PHYSICAL_BATCH_SIZE = 1
ACCUM_STEPS = 16
MAX_EPOCHS = 5000
PATIENCE_TRIGGER = 100
SWA_DURATION = 50
NUM_WORKERS = 0 # <--- CHANGED: 0 saves massive RAM (no subprocesses)
PIN_MEMORY = False # <--- CHANGED: False saves RAM (no buffer reservation)
# Ensure Directory
if not os.path.exists(SAVE_DIR): os.makedirs(SAVE_DIR)
# --- 1. HELPERS & DATASET ---
def compute_metrics(pred_probs, targets, threshold=0.5):
pred_mask = (pred_probs > threshold).float()
intersection = (pred_mask * targets).sum()
union = pred_mask.sum() + targets.sum() - intersection
iou = (intersection + 1e-6) / (union + 1e-6)
dice = (2 * intersection + 1e-6) / (pred_mask.sum() + targets.sum() + 1e-6)
return iou.item(), dice.item()
class SatMAEDataset(Dataset):
def __init__(self, x_path, y_path):
self.x_data = np.load(x_path, mmap_mode='r')
self.y_data = np.load(y_path, mmap_mode='r')
def __len__(self): return len(self.x_data)
def __getitem__(self, idx):
x_np = self.x_data[idx]
y_np = self.y_data[idx]
x_float = x_np.astype(np.float32)
y_float = y_np.astype(np.float32)
x_float[x_np == 255] = 0.0
y_float[y_np == 255] = 0.0
x_float /= 250.0
return torch.from_numpy(x_float), torch.from_numpy(y_float)
# --- 2. SETUP LOADERS (OPTIMIZED) ---
try:
print("Initializing Datasets in ECO-MODE...")
train_ds = SatMAEDataset(os.path.join(SAVE_DIR, 'train_x.npy'), os.path.join(SAVE_DIR, 'train_y.npy'))
val_ds = SatMAEDataset(os.path.join(SAVE_DIR, 'val_x.npy'), os.path.join(SAVE_DIR, 'val_y.npy'))
loaders = {
'train': DataLoader(train_ds, batch_size=PHYSICAL_BATCH_SIZE, shuffle=True,
num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY),
'val': DataLoader(val_ds, batch_size=PHYSICAL_BATCH_SIZE, shuffle=False,
num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)
}
except Exception as e:
raise RuntimeError(f" Error loading data: {e}")
# --- 3. MODEL SETUP ---
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")
# Aggressive Cleanup
torch.cuda.empty_cache()
gc.collect()
try:
model = SatMAESegmentation(img_size=224, patch_size=16, in_chans=10, num_frames=12).to(device)
criterion = CompoundLoss()
except NameError:
raise NameError(" Error: Model classes not found. Please RUN CELL 2 first.")
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=50, T_mult=2)
scaler = GradScaler()
swa_model = AveragedModel(model)
swa_scheduler = SWALR(optimizer, swa_lr=5e-5)
# --- 4. RESUME LOGIC ---
start_epoch = 0; best_iou = 0.0; patience_counter = 0; swa_active = False; swa_epoch_counter = 0
history = {'epoch': [], 'train_loss': [], 'val_loss': [], 'val_iou': []}
if os.path.exists(CHECKPOINT_PATH):
print(" Resuming from Checkpoint...")
try:
ckpt = torch.load(CHECKPOINT_PATH, map_location=device)
model.load_state_dict(ckpt['model_state_dict'])
optimizer.load_state_dict(ckpt['optimizer_state_dict'])
scheduler.load_state_dict(ckpt['scheduler_state_dict'])
start_epoch = ckpt['epoch'] + 1
best_iou = ckpt['best_iou']
patience_counter = ckpt['patience_counter']
swa_active = ckpt['swa_active']
swa_epoch_counter = ckpt['swa_epoch_counter']
if swa_active:
swa_model.load_state_dict(ckpt['swa_state_dict'])
swa_scheduler.load_state_dict(ckpt['swa_scheduler_state_dict'])
if os.path.exists(LOG_PATH):
df = pd.read_csv(LOG_PATH)
# Reload history safely
history['epoch'] = df['epoch'].tolist()
history['train_loss'] = df['train_loss'].tolist()
history['val_loss'] = df['val_loss'].tolist()
history['val_iou'] = df['val_iou'].tolist()
print(f"Resumed at Epoch {start_epoch}. Best IoU: {best_iou:.4f}")
except Exception as e:
print(f"Checkpoint load failed ({e}). Starting fresh.")
# --- 5. TRAINING LOOP ---
print(f" Training (Eco-Mode). Workers: {NUM_WORKERS} | Pin Mem: {PIN_MEMORY}")
try:
for ep in range(start_epoch, MAX_EPOCHS):
model.train()
train_loss = 0
optimizer.zero_grad(set_to_none=True)
pbar = tqdm(loaders['train'], desc=f"Ep {ep+1}", leave=False)
for i, (x, y) in enumerate(pbar):
x, y = x.to(device), y.to(device)
with autocast(enabled=True):
preds = model(x)
loss = criterion(preds, y)
loss = loss / ACCUM_STEPS
scaler.scale(loss).backward()
if (i + 1) % ACCUM_STEPS == 0:
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad(set_to_none=True)
train_loss += loss.item() * ACCUM_STEPS
del x, y, preds, loss
torch.cuda.empty_cache()
eval_model = swa_model if swa_active else model
eval_model.eval()
val_loss, val_iou_sum = 0, 0
with torch.no_grad():
for x, y in loaders['val']:
x, y = x.to(device), y.to(device)
with autocast(enabled=True):
if swa_active: preds = eval_model(x)
else: preds = model(x)
val_loss += criterion(preds, y).item()
probs = torch.sigmoid(preds)
iou, _ = compute_metrics(probs, y)
val_iou_sum += iou
avg_t = train_loss / len(loaders['train'])
avg_v = val_loss / len(loaders['val'])
avg_iou = val_iou_sum / len(loaders['val'])
with open(LOG_PATH, 'a', newline='') as f:
writer = csv.writer(f)
if ep == 0 and not os.path.exists(LOG_PATH):
writer.writerow(['epoch', 'train_loss', 'val_loss', 'val_iou'])
writer.writerow([ep+1, avg_t, avg_v, avg_iou])
status = ""
if swa_active:
swa_model.update_parameters(model)
swa_scheduler.step()
swa_epoch_counter += 1
status = f"SWA ({swa_epoch_counter}/{SWA_DURATION})"
if swa_epoch_counter >= SWA_DURATION:
print(f"Ep {ep+1} | IoU: {avg_iou:.4f} | SWA DONE")
update_bn(loaders['train'], swa_model, device=device)
torch.save(swa_model.module.state_dict(), os.path.join(SAVE_DIR, "final_swa_model.pth"))
break
else:
scheduler.step()
if avg_iou > best_iou:
best_iou = avg_iou
patience_counter = 0
torch.save(model.state_dict(), BEST_MODEL_PATH)
status = f"Best!"
else:
patience_counter += 1
status = f"Wait ({patience_counter})"
if patience_counter >= PATIENCE_TRIGGER:
print(f" Triggering SWA...")
swa_active = True
swa_epoch_counter = 0
model.load_state_dict(torch.load(BEST_MODEL_PATH))
print(f"Ep {ep+1} | T: {avg_t:.4f} | V: {avg_v:.4f} | IoU: {avg_iou:.4f} | {status}")
torch.save({
'epoch': ep, 'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'scheduler_state_dict': scheduler.state_dict(),
'swa_state_dict': swa_model.state_dict() if swa_active else None,
'swa_scheduler_state_dict': swa_scheduler.state_dict() if swa_active else None,
'best_iou': best_iou, 'patience_counter': patience_counter,
'swa_active': swa_active, 'swa_epoch_counter': swa_epoch_counter
}, CHECKPOINT_PATH)
# Force garbage collection every epoch to keep RAM low
gc.collect()
except KeyboardInterrupt:
print(" Training Interrupted. Checkpoint Saved.")
if os.path.exists(LOG_PATH):
df = pd.read_csv(LOG_PATH)
if len(df) > 0:
plt.figure(figsize=(10, 5))
plt.plot(df['train_loss'], label='Train Loss')
plt.plot(df['val_loss'], label='Val Loss')
plt.legend()
plt.title("Training Curves")
plt.savefig(os.path.join(SAVE_DIR, "loss_curve.png"))
plt.show()
Initializing Datasets in ECO-MODE... Device: cuda
/tmp/ipython-input-1738852815.py:86: FutureWarning: `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.
scaler = GradScaler()
Resuming from Checkpoint... Checkpoint load failed (PytorchStreamReader failed locating file data/101: file not found). Starting fresh. Training (Eco-Mode). Workers: 0 | Pin Mem: False
Ep 1: 0%| | 0/9 [00:00<?, ?it/s]
/tmp/ipython-input-1738852815.py:133: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
with autocast(enabled=True):
/tmp/ipython-input-1738852815.py:157: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
with autocast(enabled=True):
Ep 1 | T: 1.4697 | V: 0.9073 | IoU: 0.6619 | Best!
Ep 2: 0%| | 0/9 [00:00<?, ?it/s]
Ep 2 | T: 1.4705 | V: 0.9221 | IoU: 0.6619 | Wait (1)
Ep 3: 0%| | 0/9 [00:00<?, ?it/s]
Ep 3 | T: 1.4699 | V: 0.9855 | IoU: 0.6610 | Wait (2)
Ep 4: 0%| | 0/9 [00:00<?, ?it/s]
Ep 4 | T: 1.4708 | V: 1.1187 | IoU: 0.5708 | Wait (3)
Ep 5: 0%| | 0/9 [00:00<?, ?it/s]
Ep 5 | T: 1.4704 | V: 1.2968 | IoU: 0.4553 | Wait (4)
Ep 6: 0%| | 0/9 [00:00<?, ?it/s]
Ep 6 | T: 1.4698 | V: 1.4386 | IoU: 0.4018 | Wait (5)
Ep 7: 0%| | 0/9 [00:00<?, ?it/s]
Ep 7 | T: 1.4707 | V: 1.5153 | IoU: 0.3800 | Wait (6)
Ep 8: 0%| | 0/9 [00:00<?, ?it/s]
Ep 8 | T: 1.4707 | V: 1.5490 | IoU: 0.3717 | Wait (7)
Ep 9: 0%| | 0/9 [00:00<?, ?it/s]
Ep 9 | T: 1.4699 | V: 1.5634 | IoU: 0.3686 | Wait (8)
Ep 10: 0%| | 0/9 [00:00<?, ?it/s]
Ep 10 | T: 1.4711 | V: 1.5688 | IoU: 0.3673 | Wait (9)
Ep 11: 0%| | 0/9 [00:00<?, ?it/s]
Ep 11 | T: 1.4708 | V: 1.5707 | IoU: 0.3668 | Wait (10)
Ep 12: 0%| | 0/9 [00:00<?, ?it/s]
Ep 12 | T: 1.4709 | V: 1.5720 | IoU: 0.3667 | Wait (11)
Ep 13: 0%| | 0/9 [00:00<?, ?it/s]
Ep 13 | T: 1.4703 | V: 1.5714 | IoU: 0.3667 | Wait (12)
Ep 14: 0%| | 0/9 [00:00<?, ?it/s]
Ep 14 | T: 1.4704 | V: 1.5715 | IoU: 0.3668 | Wait (13)
Ep 15: 0%| | 0/9 [00:00<?, ?it/s]
Ep 15 | T: 1.4700 | V: 1.5719 | IoU: 0.3667 | Wait (14)
Ep 16: 0%| | 0/9 [00:00<?, ?it/s]
Ep 16 | T: 1.4701 | V: 1.5719 | IoU: 0.3668 | Wait (15)
Ep 17: 0%| | 0/9 [00:00<?, ?it/s]
Ep 17 | T: 1.4699 | V: 1.5720 | IoU: 0.3668 | Wait (16)
Ep 18: 0%| | 0/9 [00:00<?, ?it/s]
Ep 18 | T: 1.4705 | V: 1.5723 | IoU: 0.3667 | Wait (17)
Ep 19: 0%| | 0/9 [00:00<?, ?it/s]
Ep 19 | T: 1.4704 | V: 1.5726 | IoU: 0.3664 | Wait (18)
Ep 20: 0%| | 0/9 [00:00<?, ?it/s]
Ep 20 | T: 1.4700 | V: 1.5727 | IoU: 0.3667 | Wait (19)
Ep 21: 0%| | 0/9 [00:00<?, ?it/s]
Ep 21 | T: 1.4705 | V: 1.5725 | IoU: 0.3666 | Wait (20)
Ep 22: 0%| | 0/9 [00:00<?, ?it/s]
Ep 22 | T: 1.4707 | V: 1.5724 | IoU: 0.3665 | Wait (21)
Ep 23: 0%| | 0/9 [00:00<?, ?it/s]
Ep 23 | T: 1.4704 | V: 1.5722 | IoU: 0.3665 | Wait (22)
Ep 24: 0%| | 0/9 [00:00<?, ?it/s]
Ep 24 | T: 1.4706 | V: 1.5724 | IoU: 0.3665 | Wait (23)
Ep 25: 0%| | 0/9 [00:00<?, ?it/s]
Ep 25 | T: 1.4701 | V: 1.5727 | IoU: 0.3663 | Wait (24)
Ep 26: 0%| | 0/9 [00:00<?, ?it/s]
Ep 26 | T: 1.4708 | V: 1.5726 | IoU: 0.3666 | Wait (25)
Ep 27: 0%| | 0/9 [00:00<?, ?it/s]
Ep 27 | T: 1.4699 | V: 1.5723 | IoU: 0.3668 | Wait (26)
Ep 28: 0%| | 0/9 [00:00<?, ?it/s]
Ep 28 | T: 1.4702 | V: 1.5723 | IoU: 0.3666 | Wait (27)
Ep 29: 0%| | 0/9 [00:00<?, ?it/s]
Ep 29 | T: 1.4702 | V: 1.5718 | IoU: 0.3667 | Wait (28)
Ep 30: 0%| | 0/9 [00:00<?, ?it/s]
Ep 30 | T: 1.4707 | V: 1.5723 | IoU: 0.3665 | Wait (29)
Ep 31: 0%| | 0/9 [00:00<?, ?it/s]
Ep 31 | T: 1.4701 | V: 1.5722 | IoU: 0.3665 | Wait (30)
Ep 32: 0%| | 0/9 [00:00<?, ?it/s]
Ep 32 | T: 1.4705 | V: 1.5721 | IoU: 0.3666 | Wait (31)
Ep 33: 0%| | 0/9 [00:00<?, ?it/s]
Ep 33 | T: 1.4706 | V: 1.5720 | IoU: 0.3667 | Wait (32)
Ep 34: 0%| | 0/9 [00:00<?, ?it/s]
Ep 34 | T: 1.4712 | V: 1.5723 | IoU: 0.3664 | Wait (33)
Ep 35: 0%| | 0/9 [00:00<?, ?it/s]
Ep 35 | T: 1.4703 | V: 1.5722 | IoU: 0.3666 | Wait (34)
Ep 36: 0%| | 0/9 [00:00<?, ?it/s]
Ep 36 | T: 1.4709 | V: 1.5722 | IoU: 0.3664 | Wait (35)
Ep 37: 0%| | 0/9 [00:00<?, ?it/s]
Ep 37 | T: 1.4703 | V: 1.5723 | IoU: 0.3666 | Wait (36)
Ep 38: 0%| | 0/9 [00:00<?, ?it/s]
Ep 38 | T: 1.4707 | V: 1.5727 | IoU: 0.3662 | Wait (37)
Ep 39: 0%| | 0/9 [00:00<?, ?it/s]
Ep 39 | T: 1.4701 | V: 1.5724 | IoU: 0.3667 | Wait (38)
Ep 40: 0%| | 0/9 [00:00<?, ?it/s]
Ep 40 | T: 1.4709 | V: 1.5723 | IoU: 0.3664 | Wait (39)
Ep 41: 0%| | 0/9 [00:00<?, ?it/s]
Ep 41 | T: 1.4707 | V: 1.5722 | IoU: 0.3663 | Wait (40)
Ep 42: 0%| | 0/9 [00:00<?, ?it/s]
Ep 42 | T: 1.4706 | V: 1.5723 | IoU: 0.3660 | Wait (41)
Ep 43: 0%| | 0/9 [00:00<?, ?it/s]
Ep 43 | T: 1.4708 | V: 1.5723 | IoU: 0.3665 | Wait (42)
Ep 44: 0%| | 0/9 [00:00<?, ?it/s]
Ep 44 | T: 1.4707 | V: 1.5721 | IoU: 0.3669 | Wait (43)
Ep 45: 0%| | 0/9 [00:00<?, ?it/s]
Ep 45 | T: 1.4702 | V: 1.5723 | IoU: 0.3668 | Wait (44)
Ep 46: 0%| | 0/9 [00:00<?, ?it/s]
Ep 46 | T: 1.4699 | V: 1.5722 | IoU: 0.3669 | Wait (45)
Ep 47: 0%| | 0/9 [00:00<?, ?it/s]
Ep 47 | T: 1.4697 | V: 1.5723 | IoU: 0.3666 | Wait (46)
Ep 48: 0%| | 0/9 [00:00<?, ?it/s]
Ep 48 | T: 1.4707 | V: 1.5726 | IoU: 0.3663 | Wait (47)
Ep 49: 0%| | 0/9 [00:00<?, ?it/s]
Ep 49 | T: 1.4708 | V: 1.5724 | IoU: 0.3664 | Wait (48)
Ep 50: 0%| | 0/9 [00:00<?, ?it/s]
Ep 50 | T: 1.4703 | V: 1.5720 | IoU: 0.3665 | Wait (49)
Ep 51: 0%| | 0/9 [00:00<?, ?it/s]
Ep 51 | T: 1.4704 | V: 1.5722 | IoU: 0.3666 | Wait (50)
Ep 52: 0%| | 0/9 [00:00<?, ?it/s]
Ep 52 | T: 1.4709 | V: 1.5723 | IoU: 0.3664 | Wait (51)
Ep 53: 0%| | 0/9 [00:00<?, ?it/s]
Ep 53 | T: 1.4700 | V: 1.5721 | IoU: 0.3666 | Wait (52)
Ep 54: 0%| | 0/9 [00:00<?, ?it/s]
Ep 54 | T: 1.4702 | V: 1.5725 | IoU: 0.3667 | Wait (53)
Ep 55: 0%| | 0/9 [00:00<?, ?it/s]
Ep 55 | T: 1.4706 | V: 1.5729 | IoU: 0.3663 | Wait (54)
Ep 56: 0%| | 0/9 [00:00<?, ?it/s]
Ep 56 | T: 1.4711 | V: 1.5727 | IoU: 0.3664 | Wait (55)
Ep 57: 0%| | 0/9 [00:00<?, ?it/s]
Ep 57 | T: 1.4706 | V: 1.5724 | IoU: 0.3666 | Wait (56)
Ep 58: 0%| | 0/9 [00:00<?, ?it/s]
Ep 58 | T: 1.4705 | V: 1.5724 | IoU: 0.3665 | Wait (57)
Ep 59: 0%| | 0/9 [00:00<?, ?it/s]
Ep 59 | T: 1.4699 | V: 1.5725 | IoU: 0.3666 | Wait (58)
Ep 60: 0%| | 0/9 [00:00<?, ?it/s]
Ep 60 | T: 1.4704 | V: 1.5727 | IoU: 0.3666 | Wait (59)
Ep 61: 0%| | 0/9 [00:00<?, ?it/s]
Ep 61 | T: 1.4706 | V: 1.5727 | IoU: 0.3665 | Wait (60)
Ep 62: 0%| | 0/9 [00:00<?, ?it/s]
Ep 62 | T: 1.4705 | V: 1.5724 | IoU: 0.3665 | Wait (61)
Ep 63: 0%| | 0/9 [00:00<?, ?it/s]
Ep 63 | T: 1.4701 | V: 1.5721 | IoU: 0.3667 | Wait (62)
Ep 64: 0%| | 0/9 [00:00<?, ?it/s]
Ep 64 | T: 1.4701 | V: 1.5722 | IoU: 0.3664 | Wait (63)
Ep 65: 0%| | 0/9 [00:00<?, ?it/s]
Ep 65 | T: 1.4694 | V: 1.5718 | IoU: 0.3664 | Wait (64)
Ep 66: 0%| | 0/9 [00:00<?, ?it/s]
Ep 66 | T: 1.4703 | V: 1.5714 | IoU: 0.3668 | Wait (65)
Ep 67: 0%| | 0/9 [00:00<?, ?it/s]
Ep 67 | T: 1.4703 | V: 1.5717 | IoU: 0.3667 | Wait (66)
Ep 68: 0%| | 0/9 [00:00<?, ?it/s]
Ep 68 | T: 1.4699 | V: 1.5720 | IoU: 0.3669 | Wait (67)
Ep 69: 0%| | 0/9 [00:00<?, ?it/s]
Ep 69 | T: 1.4697 | V: 1.5722 | IoU: 0.3671 | Wait (68)
Ep 70: 0%| | 0/9 [00:00<?, ?it/s]
Ep 70 | T: 1.4698 | V: 1.5720 | IoU: 0.3666 | Wait (69)
Ep 71: 0%| | 0/9 [00:00<?, ?it/s]
Ep 71 | T: 1.4702 | V: 1.5722 | IoU: 0.3666 | Wait (70)
Ep 72: 0%| | 0/9 [00:00<?, ?it/s]
Ep 72 | T: 1.4697 | V: 1.5723 | IoU: 0.3667 | Wait (71)
Ep 73: 0%| | 0/9 [00:00<?, ?it/s]
Ep 73 | T: 1.4707 | V: 1.5723 | IoU: 0.3667 | Wait (72)
Ep 74: 0%| | 0/9 [00:00<?, ?it/s]
Ep 74 | T: 1.4695 | V: 1.5716 | IoU: 0.3668 | Wait (73)
Ep 75: 0%| | 0/9 [00:00<?, ?it/s]
Ep 75 | T: 1.4704 | V: 1.5723 | IoU: 0.3666 | Wait (74)
Ep 76: 0%| | 0/9 [00:00<?, ?it/s]
Ep 76 | T: 1.4703 | V: 1.5724 | IoU: 0.3667 | Wait (75)
Ep 77: 0%| | 0/9 [00:00<?, ?it/s]
Ep 77 | T: 1.4696 | V: 1.5724 | IoU: 0.3667 | Wait (76)
Ep 78: 0%| | 0/9 [00:00<?, ?it/s]
Ep 78 | T: 1.4712 | V: 1.5728 | IoU: 0.3666 | Wait (77)
Ep 79: 0%| | 0/9 [00:00<?, ?it/s]
Ep 79 | T: 1.4707 | V: 1.5722 | IoU: 0.3666 | Wait (78)
Ep 80: 0%| | 0/9 [00:00<?, ?it/s]
Ep 80 | T: 1.4709 | V: 1.5729 | IoU: 0.3662 | Wait (79)
Ep 81: 0%| | 0/9 [00:00<?, ?it/s]
Ep 81 | T: 1.4699 | V: 1.5724 | IoU: 0.3666 | Wait (80)
Ep 82: 0%| | 0/9 [00:00<?, ?it/s]
Ep 82 | T: 1.4708 | V: 1.5727 | IoU: 0.3663 | Wait (81)
Ep 83: 0%| | 0/9 [00:00<?, ?it/s]
Ep 83 | T: 1.4700 | V: 1.5724 | IoU: 0.3665 | Wait (82)
Ep 84: 0%| | 0/9 [00:00<?, ?it/s]
Ep 84 | T: 1.4711 | V: 1.5723 | IoU: 0.3666 | Wait (83)
Ep 85: 0%| | 0/9 [00:00<?, ?it/s]
Ep 85 | T: 1.4698 | V: 1.5723 | IoU: 0.3669 | Wait (84)
Ep 86: 0%| | 0/9 [00:00<?, ?it/s]
Ep 86 | T: 1.4704 | V: 1.5715 | IoU: 0.3668 | Wait (85)
Ep 87: 0%| | 0/9 [00:00<?, ?it/s]
Ep 87 | T: 1.4701 | V: 1.5721 | IoU: 0.3665 | Wait (86)
Ep 88: 0%| | 0/9 [00:00<?, ?it/s]
Ep 88 | T: 1.4699 | V: 1.5721 | IoU: 0.3667 | Wait (87)
Ep 89: 0%| | 0/9 [00:00<?, ?it/s]
Ep 89 | T: 1.4700 | V: 1.5723 | IoU: 0.3667 | Wait (88)
Ep 90: 0%| | 0/9 [00:00<?, ?it/s]
Ep 90 | T: 1.4698 | V: 1.5723 | IoU: 0.3667 | Wait (89)
Ep 91: 0%| | 0/9 [00:00<?, ?it/s]
Ep 91 | T: 1.4711 | V: 1.5723 | IoU: 0.3669 | Wait (90)
Ep 92: 0%| | 0/9 [00:00<?, ?it/s]
Ep 92 | T: 1.4706 | V: 1.5721 | IoU: 0.3667 | Wait (91)
Ep 93: 0%| | 0/9 [00:00<?, ?it/s]
Ep 93 | T: 1.4706 | V: 1.5721 | IoU: 0.3666 | Wait (92)
Ep 94: 0%| | 0/9 [00:00<?, ?it/s]
Ep 94 | T: 1.4704 | V: 1.5721 | IoU: 0.3666 | Wait (93)
Ep 95: 0%| | 0/9 [00:00<?, ?it/s]
Ep 95 | T: 1.4701 | V: 1.5722 | IoU: 0.3664 | Wait (94)
Ep 96: 0%| | 0/9 [00:00<?, ?it/s]
Ep 96 | T: 1.4700 | V: 1.5719 | IoU: 0.3663 | Wait (95)
Ep 97: 0%| | 0/9 [00:00<?, ?it/s]
Ep 97 | T: 1.4705 | V: 1.5724 | IoU: 0.3664 | Wait (96)
Ep 98: 0%| | 0/9 [00:00<?, ?it/s]
Ep 98 | T: 1.4709 | V: 1.5722 | IoU: 0.3665 | Wait (97)
Ep 99: 0%| | 0/9 [00:00<?, ?it/s]
Ep 99 | T: 1.4707 | V: 1.5728 | IoU: 0.3666 | Wait (98)
Ep 100: 0%| | 0/9 [00:00<?, ?it/s]
Ep 100 | T: 1.4697 | V: 1.5724 | IoU: 0.3665 | Wait (99)
Ep 101: 0%| | 0/9 [00:00<?, ?it/s]
Triggering SWA... Ep 101 | T: 1.4703 | V: 1.5719 | IoU: 0.3665 | Wait (100)
Ep 102: 0%| | 0/9 [00:00<?, ?it/s]
/usr/local/lib/python3.12/dist-packages/torch/optim/lr_scheduler.py:192: UserWarning: Detected call of `lr_scheduler.step()` before `optimizer.step()`. In PyTorch 1.1.0 and later, you should call them in the opposite order: `optimizer.step()` before `lr_scheduler.step()`. Failure to do this will result in PyTorch skipping the first value of the learning rate schedule. See more details at https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate warnings.warn(
Ep 102 | T: 1.4707 | V: 0.9056 | IoU: 0.6619 | SWA (1/50)
Ep 103: 0%| | 0/9 [00:00<?, ?it/s]
Ep 103 | T: 1.4703 | V: 0.9221 | IoU: 0.6619 | SWA (2/50)
Ep 104: 0%| | 0/9 [00:00<?, ?it/s]
Ep 104 | T: 1.4706 | V: 0.9854 | IoU: 0.6610 | SWA (3/50)
Ep 105: 0%| | 0/9 [00:00<?, ?it/s]
Ep 105 | T: 1.4702 | V: 1.1186 | IoU: 0.5708 | SWA (4/50)
Ep 106: 0%| | 0/9 [00:00<?, ?it/s]
Ep 106 | T: 1.4705 | V: 1.2967 | IoU: 0.4554 | SWA (5/50)
Ep 107: 0%| | 0/9 [00:00<?, ?it/s]
Ep 107 | T: 1.4703 | V: 1.4384 | IoU: 0.4018 | SWA (6/50)
Ep 108: 0%| | 0/9 [00:00<?, ?it/s]
Ep 108 | T: 1.4704 | V: 1.5144 | IoU: 0.3801 | SWA (7/50)
Ep 109: 0%| | 0/9 [00:00<?, ?it/s]
Ep 109 | T: 1.4701 | V: 1.5490 | IoU: 0.3717 | SWA (8/50)
Ep 110: 0%| | 0/9 [00:00<?, ?it/s]
Ep 110 | T: 1.4706 | V: 1.5634 | IoU: 0.3685 | SWA (9/50)
Ep 111: 0%| | 0/9 [00:00<?, ?it/s]
Ep 111 | T: 1.4710 | V: 1.5687 | IoU: 0.3677 | SWA (10/50)
Ep 112: 0%| | 0/9 [00:00<?, ?it/s]
Ep 112 | T: 1.4701 | V: 1.5711 | IoU: 0.3671 | SWA (11/50)
Ep 113: 0%| | 0/9 [00:00<?, ?it/s]
Ep 113 | T: 1.4702 | V: 1.5713 | IoU: 0.3667 | SWA (12/50)
Ep 114: 0%| | 0/9 [00:00<?, ?it/s]
Ep 114 | T: 1.4702 | V: 1.5717 | IoU: 0.3665 | SWA (13/50)
Ep 115: 0%| | 0/9 [00:00<?, ?it/s]
Ep 115 | T: 1.4705 | V: 1.5720 | IoU: 0.3666 | SWA (14/50)
Ep 116: 0%| | 0/9 [00:00<?, ?it/s]
Ep 116 | T: 1.4701 | V: 1.5727 | IoU: 0.3663 | SWA (15/50)
Ep 117: 0%| | 0/9 [00:00<?, ?it/s]
Ep 117 | T: 1.4702 | V: 1.5726 | IoU: 0.3665 | SWA (16/50)
Ep 118: 0%| | 0/9 [00:00<?, ?it/s]
Ep 118 | T: 1.4708 | V: 1.5725 | IoU: 0.3666 | SWA (17/50)
Ep 119: 0%| | 0/9 [00:00<?, ?it/s]
Ep 119 | T: 1.4708 | V: 1.5724 | IoU: 0.3665 | SWA (18/50)
Ep 120: 0%| | 0/9 [00:00<?, ?it/s]
Ep 120 | T: 1.4698 | V: 1.5725 | IoU: 0.3667 | SWA (19/50)
Ep 121: 0%| | 0/9 [00:00<?, ?it/s]
Ep 121 | T: 1.4703 | V: 1.5726 | IoU: 0.3668 | SWA (20/50)
Ep 122: 0%| | 0/9 [00:00<?, ?it/s]
Ep 122 | T: 1.4708 | V: 1.5723 | IoU: 0.3665 | SWA (21/50)
Ep 123: 0%| | 0/9 [00:00<?, ?it/s]
Ep 123 | T: 1.4707 | V: 1.5725 | IoU: 0.3664 | SWA (22/50)
Ep 124: 0%| | 0/9 [00:00<?, ?it/s]
Ep 124 | T: 1.4710 | V: 1.5726 | IoU: 0.3665 | SWA (23/50)
Ep 125: 0%| | 0/9 [00:00<?, ?it/s]
Ep 125 | T: 1.4703 | V: 1.5728 | IoU: 0.3663 | SWA (24/50)
Ep 126: 0%| | 0/9 [00:00<?, ?it/s]
Ep 126 | T: 1.4699 | V: 1.5722 | IoU: 0.3664 | SWA (25/50)
Ep 127: 0%| | 0/9 [00:00<?, ?it/s]
Ep 127 | T: 1.4708 | V: 1.5720 | IoU: 0.3663 | SWA (26/50)
Ep 128: 0%| | 0/9 [00:00<?, ?it/s]
Ep 128 | T: 1.4702 | V: 1.5723 | IoU: 0.3665 | SWA (27/50)
Ep 129: 0%| | 0/9 [00:00<?, ?it/s]
Ep 129 | T: 1.4701 | V: 1.5724 | IoU: 0.3666 | SWA (28/50)
Ep 130: 0%| | 0/9 [00:00<?, ?it/s]
Ep 130 | T: 1.4703 | V: 1.5725 | IoU: 0.3666 | SWA (29/50)
Ep 131: 0%| | 0/9 [00:00<?, ?it/s]
Ep 131 | T: 1.4709 | V: 1.5725 | IoU: 0.3667 | SWA (30/50)
Ep 132: 0%| | 0/9 [00:00<?, ?it/s]
Ep 132 | T: 1.4706 | V: 1.5726 | IoU: 0.3666 | SWA (31/50)
Ep 133: 0%| | 0/9 [00:00<?, ?it/s]
Ep 133 | T: 1.4706 | V: 1.5722 | IoU: 0.3668 | SWA (32/50)
Ep 134: 0%| | 0/9 [00:00<?, ?it/s]
Ep 134 | T: 1.4698 | V: 1.5720 | IoU: 0.3667 | SWA (33/50)
Ep 135: 0%| | 0/9 [00:00<?, ?it/s]
Ep 135 | T: 1.4710 | V: 1.5723 | IoU: 0.3667 | SWA (34/50)
Ep 136: 0%| | 0/9 [00:00<?, ?it/s]
Ep 136 | T: 1.4711 | V: 1.5722 | IoU: 0.3665 | SWA (35/50)
Ep 137: 0%| | 0/9 [00:00<?, ?it/s]
Ep 137 | T: 1.4702 | V: 1.5727 | IoU: 0.3664 | SWA (36/50)
Ep 138: 0%| | 0/9 [00:00<?, ?it/s]
Ep 138 | T: 1.4708 | V: 1.5725 | IoU: 0.3668 | SWA (37/50)
Ep 139: 0%| | 0/9 [00:00<?, ?it/s]
Ep 139 | T: 1.4707 | V: 1.5724 | IoU: 0.3667 | SWA (38/50)
Ep 140: 0%| | 0/9 [00:00<?, ?it/s]
Ep 140 | T: 1.4703 | V: 1.5724 | IoU: 0.3666 | SWA (39/50)
Ep 141: 0%| | 0/9 [00:00<?, ?it/s]
Ep 141 | T: 1.4715 | V: 1.5723 | IoU: 0.3664 | SWA (40/50)
Ep 142: 0%| | 0/9 [00:00<?, ?it/s]
Ep 142 | T: 1.4705 | V: 1.5723 | IoU: 0.3666 | SWA (41/50)
Ep 143: 0%| | 0/9 [00:00<?, ?it/s]
Ep 143 | T: 1.4708 | V: 1.5724 | IoU: 0.3667 | SWA (42/50)
Ep 144: 0%| | 0/9 [00:00<?, ?it/s]
Ep 144 | T: 1.4701 | V: 1.5724 | IoU: 0.3665 | SWA (43/50)
Ep 145: 0%| | 0/9 [00:00<?, ?it/s]
Ep 145 | T: 1.4709 | V: 1.5724 | IoU: 0.3666 | SWA (44/50)
Ep 146: 0%| | 0/9 [00:00<?, ?it/s]
Ep 146 | T: 1.4711 | V: 1.5726 | IoU: 0.3666 | SWA (45/50)
Ep 147: 0%| | 0/9 [00:00<?, ?it/s]
Ep 147 | T: 1.4697 | V: 1.5723 | IoU: 0.3665 | SWA (46/50)
Ep 148: 0%| | 0/9 [00:00<?, ?it/s]
Ep 148 | T: 1.4704 | V: 1.5723 | IoU: 0.3667 | SWA (47/50)
Ep 149: 0%| | 0/9 [00:00<?, ?it/s]
Ep 149 | T: 1.4697 | V: 1.5725 | IoU: 0.3666 | SWA (48/50)
Ep 150: 0%| | 0/9 [00:00<?, ?it/s]
Ep 150 | T: 1.4703 | V: 1.5725 | IoU: 0.3668 | SWA (49/50)
Ep 151: 0%| | 0/9 [00:00<?, ?it/s]
Ep 151 | IoU: 0.3666 | ✅ SWA DONE
--------------------------------------------------------------------------- KeyError Traceback (most recent call last) /usr/local/lib/python3.12/dist-packages/pandas/core/indexes/base.py in get_loc(self, key) 3804 try: -> 3805 return self._engine.get_loc(casted_key) 3806 except KeyError as err: index.pyx in pandas._libs.index.IndexEngine.get_loc() index.pyx in pandas._libs.index.IndexEngine.get_loc() pandas/_libs/hashtable_class_helper.pxi in pandas._libs.hashtable.PyObjectHashTable.get_item() pandas/_libs/hashtable_class_helper.pxi in pandas._libs.hashtable.PyObjectHashTable.get_item() KeyError: 'train_loss' The above exception was the direct cause of the following exception: KeyError Traceback (most recent call last) /tmp/ipython-input-1738852815.py in <cell line: 0>() 223 if len(df) > 0: 224 plt.figure(figsize=(10, 5)) --> 225 plt.plot(df['train_loss'], label='Train Loss') 226 plt.plot(df['val_loss'], label='Val Loss') 227 plt.legend() /usr/local/lib/python3.12/dist-packages/pandas/core/frame.py in __getitem__(self, key) 4100 if self.columns.nlevels > 1: 4101 return self._getitem_multilevel(key) -> 4102 indexer = self.columns.get_loc(key) 4103 if is_integer(indexer): 4104 indexer = [indexer] /usr/local/lib/python3.12/dist-packages/pandas/core/indexes/base.py in get_loc(self, key) 3810 ): 3811 raise InvalidIndexError(key) -> 3812 raise KeyError(key) from err 3813 except TypeError: 3814 # If we have a listlike key, _check_indexing_error will raise KeyError: 'train_loss'
<Figure size 1000x500 with 0 Axes>
In [12]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import os
from torch.utils.data import DataLoader
# --- CONFIG ---
SAVE_DIR = '/content/drive/MyDrive/SatMAE_Scratch_Results_12Frames/'
MODEL_PATH = os.path.join(SAVE_DIR, "final_swa_model.pth") # Using the SWA model
# --- 1. LOAD MODEL ---
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")
try:
# Re-initialize model architecture
model = SatMAESegmentation(img_size=224, patch_size=16, in_chans=10, num_frames=12).to(device)
# Load Weights
print(f"Loading weights from: {os.path.basename(MODEL_PATH)}")
state_dict = torch.load(MODEL_PATH, map_location=device)
model.load_state_dict(state_dict)
model.eval()
print(" Model Loaded Successfully!")
except Exception as e:
print(f" Error loading model: {e}")
print("Make sure Cell 2 (Architecture) was run.")
# --- 2. LOAD ONE BATCH OF VAL DATA ---
# We use the existing 'val_ds' from previous cells
try:
val_loader = DataLoader(val_ds, batch_size=1, shuffle=True)
images, masks = next(iter(val_loader)) # Get 1 random sample
images, masks = images.to(device), masks.to(device)
except NameError:
print(" Error: 'val_ds' not found. Please re-run Cell 4 dataset setup part.")
# --- 3. RUN PREDICTION ---
with torch.no_grad():
# Forward Pass
logits = model(images)
probs = torch.sigmoid(logits)
pred_mask = (probs > 0.5).float() # Threshold at 50% confidence
# --- 4. VISUALIZATION ---
# Helper to normalize RGB for display
def get_rgb(img_tensor, frame_idx=6):
# Shape: (B, T, C, H, W) -> (10, H, W) for one frame
# Bands: ['B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B8', 'B8A', 'B11', 'B12']
# Indices: B4(Red)=2, B3(Green)=1, B2(Blue)=0
img = img_tensor[0, frame_idx].cpu().numpy()
rgb = img[[2, 1, 0], :, :] # Pick RGB bands
# Normalize for display (Sentinel-2 is roughly 0-0.3 reflectance typically, here scaled 0-1)
rgb = np.clip(rgb * 3.0, 0, 1) # Brighten it up x3
return np.transpose(rgb, (1, 2, 0)) # CHW -> HWC
# Prepare plots
fig, ax = plt.subplots(1, 3, figsize=(18, 6))
# A. Satellite Image (RGB from Frame 6)
ax[0].imshow(get_rgb(images))
ax[0].set_title("Sentinel-2 Input (RGB - Frame 6)")
ax[0].axis('off')
# B. Ground Truth Mask
ax[1].imshow(masks[0, 0].cpu().numpy(), cmap='gray')
ax[1].set_title("Ground Truth (Wheat Mask)")
ax[1].axis('off')
# C. Model Prediction
ax[2].imshow(pred_mask[0, 0].cpu().numpy(), cmap='jet', alpha=1.0)
ax[2].set_title("SatMAE Prediction")
ax[2].axis('off')
plt.tight_layout()
plt.show()
print(f"Sample IoU: {compute_metrics(probs, masks)[0]:.4f}")
Device: cuda Loading weights from: final_swa_model.pth ✅ Model Loaded Successfully!
Sample IoU: 0.3799
In [13]:
import pandas as pd
import matplotlib.pyplot as plt
import torch
import numpy as np
import os
from torch.utils.data import DataLoader
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score, confusion_matrix
# --- CONFIG ---
SAVE_DIR = '/content/drive/MyDrive/SatMAE_Scratch_Results_12Frames/'
LOG_PATH = os.path.join(SAVE_DIR, "training_log.csv")
MODEL_PATH = os.path.join(SAVE_DIR, "final_swa_model.pth")
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# ==========================================
# PART 1: PLOT TRAINING HISTORY
# ==========================================
if os.path.exists(LOG_PATH):
try:
df = pd.read_csv(LOG_PATH)
# Setup Plots
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
# Plot Loss
ax1.plot(df['epoch'], df['train_loss'], label='Train Loss', color='orange', linewidth=2)
ax1.plot(df['epoch'], df['val_loss'], label='Val Loss', color='blue', linewidth=2)
ax1.set_title("Training vs Validation Loss")
ax1.set_xlabel("Epochs")
ax1.set_ylabel("Loss")
ax1.legend()
ax1.grid(True, alpha=0.3)
# Plot IoU
ax2.plot(df['epoch'], df['val_iou'], label='Val IoU', color='green', linewidth=2)
ax2.set_title("Validation IoU (Quality) Over Time")
ax2.set_xlabel("Epochs")
ax2.set_ylabel("IoU Score")
ax2.legend()
ax2.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
print(" Historical graphs generated from log file.")
except Exception as e:
print(f" Could not plot graphs: {e}")
else:
print(" No log file found. Cannot plot history.")
# ==========================================
# PART 2: CALCULATE DEEP METRICS (F1, PRECISION, RECALL)
# ==========================================
print("\n" + "="*40)
print(" STARTING DEEP METRIC EVALUATION...")
print("="*40)
# 1. Load Model
try:
model = SatMAESegmentation(img_size=224, patch_size=16, in_chans=10, num_frames=12).to(device)
state_dict = torch.load(MODEL_PATH, map_location=device)
model.load_state_dict(state_dict)
model.eval()
print(" Final Model Loaded.")
except Exception as e:
print(f" Error loading model: {e}")
# Stop execution if model fails
raise RuntimeError("Model needed for metrics.")
# 2. Setup Data
try:
# Use existing validation dataset
eval_loader = DataLoader(val_ds, batch_size=1, shuffle=False)
except NameError:
print(" Error: 'val_ds' not defined. Please run the Dataset Setup cell first.")
# 3. Accumulate Pixels (Global Calculation)
# We flatten all predictions and targets into one giant list for accurate global stats
all_preds = []
all_targets = []
print(f" Evaluating over {len(eval_loader)} validation samples...")
with torch.no_grad():
for images, masks in eval_loader:
images = images.to(device)
# Predict
logits = model(images)
probs = torch.sigmoid(logits)
preds = (probs > 0.5).cpu().numpy().flatten().astype(np.uint8)
# Ground Truth
targets = masks.cpu().numpy().flatten().astype(np.uint8)
all_preds.append(preds)
all_targets.append(targets)
# Concatenate all pixels
all_preds = np.concatenate(all_preds)
all_targets = np.concatenate(all_targets)
# 4. Compute Metrics
# Note: Zero_division=0 handles cases where model predicts nothing (safe division)
precision = precision_score(all_targets, all_preds, zero_division=0)
recall = recall_score(all_targets, all_preds, zero_division=0) # Sensitivity
f1 = f1_score(all_targets, all_preds, zero_division=0)
accuracy = accuracy_score(all_targets, all_preds)
# IoU Calculation (Jaccard)
intersection = np.logical_and(all_targets, all_preds).sum()
union = np.logical_or(all_targets, all_preds).sum()
iou = intersection / union if union > 0 else 0.0
# Confusion Matrix Elements
tn, fp, fn, tp = confusion_matrix(all_targets, all_preds).ravel()
# ==========================================
# PART 3: FINAL REPORT
# ==========================================
print("\n" + "="*40)
print(" FINAL MODEL PERFORMANCE REPORT")
print("="*40)
print(f"IoU (Intersection over Union): {iou:.4f} (Main Metric)")
print(f"F1 Score (Dice Coefficient): {f1:.4f} (Harmonic Mean)")
print("-" * 40)
print(f"Precision (True Pos / Pred Pos):{precision:.4f} (How trustworthy 'Wheat' is)")
print(f"Recall (True Pos / Actual Pos): {recall:.4f} (How much Wheat we found)")
print(f"Accuracy (Pixel-wise): {accuracy:.4f}")
print("-" * 40)
print("Confusion Matrix (Pixels):")
print(f" True Wheat Found (TP): {tp}")
print(f" Background Found (TN): {tn}")
print(f" Wheat Missed (FN): {fn}")
print(f" False Alarm (FP): {fp}")
print("="*40)
Could not plot graphs: 'epoch' ======================================== STARTING DEEP METRIC EVALUATION... ======================================== Final Model Loaded. ⏳ Evaluating over 3 validation samples... ======================================== FINAL MODEL PERFORMANCE REPORT ======================================== IoU (Intersection over Union): 0.3686 (Main Metric) F1 Score (Dice Coefficient): 0.5386 (Harmonic Mean) ---------------------------------------- Precision (True Pos / Pred Pos):0.6618 (How trustworthy 'Wheat' is) Recall (True Pos / Actual Pos): 0.4541 (How much Wheat we found) Accuracy (Pixel-wise): 0.4850 ---------------------------------------- Confusion Matrix (Pixels): True Wheat Found (TP): 45250 Background Found (TN): 27763 Wheat Missed (FN): 54387 False Alarm (FP): 23128 ========================================