In [6]:
import ee
import geemap
import datetime
import os
import numpy as np
import rasterio
import time
from sklearn.model_selection import train_test_split
!pip install -q geedim geemap

# --- 1. INITIALIZATION ---
try:
    ee.Initialize(project='[REDACTED_FOR_SECURITY]')
except:
    ee.Authenticate()
    ee.Initialize(project='[REDACTED_FOR_SECURITY]')

# Configuration
year = 2021
START_DATE = f'{year}-10-16'
END_DATE = f'{year + 1}-04-16'
ASSET_ID = '[REDACTED_FOR_SECURITY]'
DRIVE_FOLDER = 'SatMAE_Scratch_Results_12Frames'
SAVE_DIR = f'/content/drive/MyDrive/{DRIVE_FOLDER}/'

if not os.path.exists(SAVE_DIR):
    os.makedirs(SAVE_DIR)

# Assets
wheat_mask = ee.Image(ASSET_ID)
roi = wheat_mask.geometry()
s2Bands = ['B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B8', 'B8A', 'B11', 'B12']
MAX_CLOUD_PROB = 70

# Sentinel-2 Processing
s2Raw = (ee.ImageCollection('COPERNICUS/S2_SR_HARMONIZED')
         .filterDate(START_DATE, END_DATE).filterBounds(roi)
         .filter(ee.Filter.lt('CLOUDY_PIXEL_PERCENTAGE', 100))
         .map(lambda img: img.updateMask(
             img.select('MSK_SNWPRB').lt(1)
             .add(img.select('SCL').gte(4).And(img.select('SCL').lte(6)))
         )))

s2Clouds = ee.ImageCollection('COPERNICUS/S2_CLOUD_PROBABILITY').filterDate(START_DATE, END_DATE).filterBounds(roi)
s2Joined = ee.ImageCollection(ee.Join.saveFirst('cloud_mask_img').apply(
    s2Raw, s2Clouds, ee.Filter.equals(leftField='system:index', rightField='system:index')))

def processS2(img):
    isCloud = ee.Image(img.get('cloud_mask_img')).select('probability').gt(MAX_CLOUD_PROB)
    optical = img.select(s2Bands).multiply(0.025)
    ndvi = optical.normalizedDifference(['B8', 'B4']).rename('NDVI')
    combined_mask = isCloud.Not()
    return (optical.addBands(ndvi).updateMask(combined_mask).copyProperties(img, ['system:time_start']))

s2Clean = s2Joined.map(processS2)

# Fortnightly Composites
intervals = []
cur = datetime.datetime.fromisoformat(START_DATE)
end = datetime.datetime.fromisoformat(END_DATE)
while cur < end:
    if cur.day == 1:
        nxt = cur + datetime.timedelta(days=14)
        str_lbl = f"{cur.year}_{cur.month:02d}_1"
    else:
        next_month = cur.replace(day=1) + datetime.timedelta(days=32)
        nxt = next_month.replace(day=1)
        str_lbl = f"{cur.year}_{cur.month:02d}_2"

    if nxt > end: nxt = end
    intervals.append([cur.strftime('%Y-%m-%d'), (nxt).strftime('%Y-%m-%d'), str_lbl])
    cur = nxt

intervals = intervals[:12]
print(f"Generating {len(intervals)} Time Frames...")

def makeComposite(item):
    start, end = ee.Date(ee.List(item).get(0)), ee.Date(ee.List(item).get(1))
    return s2Clean.filterDate(start, end).qualityMosaic('NDVI').select(s2Bands)\
        .clamp(0, 250).toByte().unmask(255)\
        .set('fortnight_label', ee.List(item).get(2))

fortnightlyCol = ee.ImageCollection.fromImages(ee.List(intervals).map(makeComposite))

# --- ROBUST EXPORT LOGIC ---
def monitor_task(task):
    print(f"Submitting Task: {task.config['description']}...")
    try:
        task.start()
    except Exception as e:
        if "already exists" in str(e):
            print("Task already running/completed.")
            return
        raise e

    while True:
        status = task.status()
        state = status['state']
        if state in ['COMPLETED', 'FAILED', 'CANCELLED']:
            print(f"\nTask Finished with state: {state}")
            if state == 'FAILED':
                print(f"Error Message: {status.get('error_message', 'Unknown Error')}")
                raise RuntimeError("GEE Export Failed.")
            break
        print(f"Status: {state}...", end='\r')
        time.sleep(10)

def wait_for_file(filepath, timeout=300): # INCREASED TO 5 MINUTES
    """Waits for a file to appear in Colab Drive mount with Forced Refresh."""
    print(f"Waiting for Drive sync: {os.path.basename(filepath)}...")
    start = time.time()

    while not os.path.exists(filepath):
        elapsed = time.time() - start
        if elapsed > timeout:
            print(f"\n TIMEOUT: File {filepath} not found locally.")
            print("Check your Google Drive folder in a new tab.")
            print("If the file exists there, simply RE-RUN this cell.")
            raise FileNotFoundError(f"Timeout: {filepath} did not appear after {timeout}s.")

        # AGGRESSIVE REFRESH LOGIC
        # We list the directory contents to force the OS to check the drive again
        try:
            _ = os.listdir(os.path.dirname(filepath))
        except:
            pass

        print(f"Syncing... ({int(elapsed)}s)", end='\r')
        time.sleep(5)

    print(f"\n Found: {filepath}")
    return True

def generate_local_dataset():
    print("Starting Robust Data Generation (Batch Export)...")
    centroid = roi.centroid()
    buffer_poly = centroid.buffer(2500).bounds()

    # 1. Export Stack
    stack_name = 'local_stack'
    stack_file = os.path.join(SAVE_DIR, f'{stack_name}.tif')

    # Only submit if file missing
    if not os.path.exists(stack_file):
        task_stack = ee.batch.Export.image.toDrive(
            image=fortnightlyCol.toBands(),
            description='Export_Stack_12Frames',
            folder=DRIVE_FOLDER,
            fileNamePrefix=stack_name,
            region=buffer_poly,
            scale=10,
            crs='EPSG:4326',
            fileFormat='GeoTIFF',
            maxPixels=1e9
        )
        monitor_task(task_stack)
    else:
        print(f"Found existing stack: {stack_file}")

    # 2. Export Mask
    mask_name = 'local_mask'
    mask_file = os.path.join(SAVE_DIR, f'{mask_name}.tif')

    if not os.path.exists(mask_file):
        task_mask = ee.batch.Export.image.toDrive(
            image=wheat_mask,
            description='Export_Mask',
            folder=DRIVE_FOLDER,
            fileNamePrefix=mask_name,
            region=buffer_poly,
            scale=10,
            crs='EPSG:4326',
            fileFormat='GeoTIFF',
            maxPixels=1e9
        )
        monitor_task(task_mask)
    else:
        print(f"Found existing mask: {mask_file}")

    # 3. Sync Wait (Crucial Step)
    wait_for_file(stack_file)
    wait_for_file(mask_file)

    # 4. Process
    print("Processing GeoTIFFs to Numpy...")

    with rasterio.open(stack_file) as src_stack, rasterio.open(mask_file) as src_mask:
        img_data = src_stack.read()
        mask_data = src_mask.read(1)

        H, W = mask_data.shape
        if img_data.shape[0] < 120:
             print(f"Warning: Got {img_data.shape[0]} bands. Padding with zeros.")
             padding = np.zeros((120 - img_data.shape[0], H, W), dtype=img_data.dtype)
             img_data = np.concatenate([img_data, padding], axis=0)

        img_reshaped = img_data[:120].reshape(12, 10, H, W)

        X_list, y_list = [], []
        if H < 224 or W < 224:
             print("ROI too small.")
             return
        else:
            for r in range(0, H-224, 112):
                for c in range(0, W-224, 112):
                    chip_x = img_reshaped[:, :, r:r+224, c:c+224]
                    chip_y = mask_data[r:r+224, c:c+224]

                    if np.mean(chip_x == 255) > 0.30: continue
                    if np.mean(chip_y > 0) > 0.01:
                        X_list.append(chip_x)
                        y_list.append(chip_y)

        if len(X_list) == 0: raise RuntimeError("No valid tiles found.")

        X = np.array(X_list).astype(np.uint8)
        y = np.array(y_list).astype(np.uint8)[:, None, :, :]

        split_idx = int(0.8 * len(X))
        X_train, X_val = X[:split_idx], X[split_idx:]
        y_train, y_val = y[:split_idx], y[split_idx:]

        print(f"Spatial Split: Train={len(X_train)}, Val={len(X_val)}")

        np.save(os.path.join(SAVE_DIR, 'train_x.npy'), X_train)
        np.save(os.path.join(SAVE_DIR, 'train_y.npy'), y_train)
        np.save(os.path.join(SAVE_DIR, 'val_x.npy'), X_val)
        np.save(os.path.join(SAVE_DIR, 'val_y.npy'), y_val)
        print("Done.")

generate_local_dataset()
Generating 12 Time Frames...
Starting Robust Data Generation (Batch Export)...
Submitting Task: Export_Stack_12Frames...

Task Finished with state: COMPLETED
Submitting Task: Export_Mask...

Task Finished with state: COMPLETED
Waiting for Drive sync: local_stack.tif...

❌ TIMEOUT: File /content/drive/MyDrive/SatMAE_Scratch_Results_12Frames/local_stack.tif not found locally.
Check your Google Drive folder in a new tab.
If the file exists there, simply RE-RUN this cell.
---------------------------------------------------------------------------
FileNotFoundError                         Traceback (most recent call last)
/tmp/ipython-input-3111439864.py in <cell line: 0>()
    229         print("Done.")
    230 
--> 231 generate_local_dataset()

/tmp/ipython-input-3111439864.py in generate_local_dataset()
    179 
    180     # 3. Sync Wait (Crucial Step)
--> 181     wait_for_file(stack_file)
    182     wait_for_file(mask_file)
    183 

/tmp/ipython-input-3111439864.py in wait_for_file(filepath, timeout)
    117             print("Check your Google Drive folder in a new tab.")
    118             print("If the file exists there, simply RE-RUN this cell.")
--> 119             raise FileNotFoundError(f"Timeout: {filepath} did not appear after {timeout}s.")
    120 
    121         # AGGRESSIVE REFRESH LOGIC

FileNotFoundError: Timeout: /content/drive/MyDrive/SatMAE_Scratch_Results_12Frames/local_stack.tif did not appear after 300s.
In [1]:
import ee
import geemap
import os
import numpy as np
import rasterio
from google.colab import drive

# --- 1. MOUNT DRIVE ---
# This allows us to access the files you already saved
drive.mount('/content/drive')

# --- 2. INITIALIZE EARTH ENGINE ---
try:
    ee.Initialize(project='[REDACTED_FOR_SECURITY]')
    print(" Earth Engine Initialized.")
except:
    ee.Authenticate()
    ee.Initialize(project='[REDACTED_FOR_SECURITY]')
    print(" Earth Engine Authenticated & Initialized.")

# --- 3. CONFIGURATION ---
DRIVE_FOLDER = 'SatMAE_Scratch_Results_12Frames'
SAVE_DIR = f'/content/drive/MyDrive/{DRIVE_FOLDER}/'

stack_file = os.path.join(SAVE_DIR, 'local_stack.tif')
mask_file = os.path.join(SAVE_DIR, 'local_mask.tif')

# --- 4. PROCESSING LOGIC (TIF -> NPY) ---
def process_existing_files():
    if not os.path.exists(stack_file) or not os.path.exists(mask_file):
        print(f" Error: Files not found in {SAVE_DIR}")
        print("Please check your Google Drive folder name.")
        return

    print(f" Found Stack: {os.path.basename(stack_file)}")
    print(f" Found Mask:  {os.path.basename(mask_file)}")
    print(" Processing into Training Data (Spatial Split + NoData Filter)...")

    with rasterio.open(stack_file) as src_stack, rasterio.open(mask_file) as src_mask:
        img_data = src_stack.read()
        mask_data = src_mask.read(1)

        # Padding Check (Ensure 120 channels: 12 frames * 10 bands)
        H, W = mask_data.shape
        if img_data.shape[0] < 120:
             print(f"   Warning: Got {img_data.shape[0]} bands. Padding with zeros.")
             padding = np.zeros((120 - img_data.shape[0], H, W), dtype=img_data.dtype)
             img_data = np.concatenate([img_data, padding], axis=0)

        img_reshaped = img_data[:120].reshape(12, 10, H, W)

        # Chip Generation
        X_list, y_list = [], []
        if H < 224 or W < 224:
             print(" Error: ROI is smaller than 224x224.")
             return
        else:
            # Stride 112 = 50% Overlap
            for r in range(0, H-224, 112):
                for c in range(0, W-224, 112):
                    chip_x = img_reshaped[:, :, r:r+224, c:c+224]
                    chip_y = mask_data[r:r+224, c:c+224]

                    # Filter Garbage (Discard if >30% is NoData/255)
                    if np.mean(chip_x == 255) > 0.30: continue

                    # Filter Empty Labels (Keep if >1% wheat)
                    if np.mean(chip_y > 0) > 0.01:
                        X_list.append(chip_x)
                        y_list.append(chip_y)

        if len(X_list) == 0:
            print(" No valid tiles found! (Mask might be empty or threshold too strict).")
            return

        # Spatial Split (Top 80% Train, Bottom 20% Val)
        X = np.array(X_list).astype(np.uint8)
        y = np.array(y_list).astype(np.uint8)[:, None, :, :]

        split_idx = int(0.8 * len(X))
        X_train, X_val = X[:split_idx], X[split_idx:]
        y_train, y_val = y[:split_idx], y[split_idx:]

        print("-" * 30)
        print(f" Data Ready!")
        print(f"   Train Tiles: {len(X_train)}")
        print(f"   Val Tiles:   {len(X_val)}")

        # Save to Drive
        np.save(os.path.join(SAVE_DIR, 'train_x.npy'), X_train)
        np.save(os.path.join(SAVE_DIR, 'train_y.npy'), y_train)
        np.save(os.path.join(SAVE_DIR, 'val_x.npy'), X_val)
        np.save(os.path.join(SAVE_DIR, 'val_y.npy'), y_val)
        print(" .npy files saved successfully.")

# Run the processing
process_existing_files()
Mounted at /content/drive
 Earth Engine Authenticated & Initialized.
 Found Stack: local_stack.tif
 Found Mask:  local_mask.tif
 Processing into Training Data (Spatial Split + NoData Filter)...
------------------------------
✅ Data Ready!
   Train Tiles: 9
   Val Tiles:   3
💾 .npy files saved successfully.
In [3]:
import os
import numpy as np
import rasterio
from google.colab import drive

# --- 1. MOUNT DRIVE ---
# We need to see your Google Drive files
drive.mount('/content/drive', force_remount=True)

# --- 2. CONFIGURATION ---
# This must match exactly where your files are
DRIVE_FOLDER = 'SatMAE_Scratch_Results_12Frames'
SAVE_DIR = f'/content/drive/MyDrive/{DRIVE_FOLDER}/'

stack_file = os.path.join(SAVE_DIR, 'local_stack.tif')
mask_file = os.path.join(SAVE_DIR, 'local_mask.tif')

# --- 3. PROCESSING FUNCTION (No GEE required) ---
def generate_dataset_from_drive():
    print(f" Checking folder: {SAVE_DIR}")

    if not os.path.exists(stack_file) or not os.path.exists(mask_file):
        print(f" Error: Files not found!")
        print(f"   Looking for: {stack_file}")
        print("   Please check if the folder name is correct.")
        return

    print(f"Found Stack: {os.path.basename(stack_file)}")
    print(f" Found Mask:  {os.path.basename(mask_file)}")
    print(" Processing GeoTIFFs into NPY (Training Data)...")

    # Open the files directly from Drive
    with rasterio.open(stack_file) as src_stack, rasterio.open(mask_file) as src_mask:
        img_data = src_stack.read()
        mask_data = src_mask.read(1)

        # 1. Padding Check (Ensure 120 channels)
        H, W = mask_data.shape
        if img_data.shape[0] < 120:
             print(f"   Warning: Got {img_data.shape[0]} bands. Padding with zeros.")
             padding = np.zeros((120 - img_data.shape[0], H, W), dtype=img_data.dtype)
             img_data = np.concatenate([img_data, padding], axis=0)

        img_reshaped = img_data[:120].reshape(12, 10, H, W)

        # 2. Chip Generation
        X_list, y_list = [], []

        if H < 224 or W < 224:
             print(" Error: Image is smaller than 224x224. Cannot create tiles.")
             return
        else:
            # Stride 112 = 50% Overlap
            for r in range(0, H-224, 112):
                for c in range(0, W-224, 112):
                    chip_x = img_reshaped[:, :, r:r+224, c:c+224]
                    chip_y = mask_data[r:r+224, c:c+224]

                    # Filter Garbage (255)
                    if np.mean(chip_x == 255) > 0.30: continue
                    # Filter Empty Labels
                    if np.mean(chip_y > 0) > 0.01:
                        X_list.append(chip_x)
                        y_list.append(chip_y)

        if len(X_list) == 0:
            print(" No valid tiles found! (Check if mask has wheat pixels).")
            return

        # 3. Spatial Split (Top 80% Train, Bottom 20% Val)
        X = np.array(X_list).astype(np.uint8)
        y = np.array(y_list).astype(np.uint8)[:, None, :, :]

        split_idx = int(0.8 * len(X))
        X_train, X_val = X[:split_idx], X[split_idx:]
        y_train, y_val = y[:split_idx], y[split_idx:]

        print("-" * 30)
        print(f" Data Processed Successfully!")
        print(f"   Train Chips: {len(X_train)}")
        print(f"   Val Chips:   {len(X_val)}")

        # 4. Save NPY files back to Drive
        np.save(os.path.join(SAVE_DIR, 'train_x.npy'), X_train)
        np.save(os.path.join(SAVE_DIR, 'train_y.npy'), y_train)
        np.save(os.path.join(SAVE_DIR, 'val_x.npy'), X_val)
        np.save(os.path.join(SAVE_DIR, 'val_y.npy'), y_val)
        print(" .npy files saved.")

# --- RUN IT ---
generate_dataset_from_drive()
Mounted at /content/drive
 Checking folder: /content/drive/MyDrive/SatMAE_Scratch_Results_12Frames/
Found Stack: local_stack.tif
 Found Mask:  local_mask.tif
 Processing GeoTIFFs into NPY (Training Data)...
------------------------------
 Data Processed Successfully!
   Train Chips: 9
   Val Chips:   3
 .npy files saved.
In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim.swa_utils import AveragedModel, SWALR, update_bn
from tqdm.auto import tqdm

# --- 1. METRICS ---
def compute_metrics(pred_probs, targets, threshold=0.5):
    pred_mask = (pred_probs > threshold).float()
    intersection = (pred_mask * targets).sum()
    union = pred_mask.sum() + targets.sum() - intersection
    iou = (intersection + 1e-6) / (union + 1e-6)
    dice = (2 * intersection + 1e-6) / (pred_mask.sum() + targets.sum() + 1e-6)
    return iou.item(), dice.item()

# --- 2. LOSS FUNCTIONS ---
class FastHausdorffLoss(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, pred, gt):
        probs = torch.sigmoid(pred)

        # Morphological Erosion via MaxPool
        # Invert -> MaxPool -> Invert = Erosion
        p_eroded = -F.max_pool2d(-probs, kernel_size=3, stride=1, padding=1)
        t_eroded = -F.max_pool2d(-gt, kernel_size=3, stride=1, padding=1)

        p_edge = probs - p_eroded
        t_edge = gt - t_eroded

        # --- MATH FIX: SCALING ---
        # Edge differences are sparse (mostly 0). The mean is tiny.
        # We multiply by 20.0 to bring the magnitude to ~0.1-0.5 range
        # so it balances with Dice Loss.
        return (p_edge - t_edge).abs().mean() * 20.0

class CompoundLoss(nn.Module):
    def forward(self, inputs, targets):
        # Dice
        probs = torch.sigmoid(inputs)
        inter = (probs * targets).sum()
        dice_loss = 1 - (2. * inter / (probs.sum() + targets.sum() + 1e-6))

        # Fast Hausdorff
        fast_hd = FastHausdorffLoss()(inputs, targets)

        # 70% Overlap, 30% Shape
        return 0.7 * dice_loss + 0.3 * fast_hd

class TestTimeAugmentation:
    def __init__(self, model): self.model = model
    def apply(self, x):
        preds = []
        with torch.no_grad():
            preds.append(torch.sigmoid(self.model(x)))
            preds.append(torch.flip(torch.sigmoid(self.model(torch.flip(x, [-1]))), [-1]))
            preds.append(torch.flip(torch.sigmoid(self.model(torch.flip(x, [-2]))), [-2]))
            preds.append(torch.flip(torch.sigmoid(self.model(torch.flip(x, [-2,-1]))), [-2,-1]))
        return torch.stack(preds).mean(dim=0)

# --- 3. ARCHITECTURE (Standard SatMAE) ---
class SatMAEPatchEmbed(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_chans=10, embed_dim=768):
        super().__init__()
        self.num_patches = (img_size // patch_size) ** 2
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        B, T, C, H, W = x.shape
        x = x.reshape(B * T, C, H, W)
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x.reshape(B, T, -1, 768)

class SatMAEEncoder(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_chans=10, num_frames=12, embed_dim=768, depth=12, num_heads=12):
        super().__init__()
        self.patch_embed = SatMAEPatchEmbed(img_size, patch_size, in_chans, embed_dim)
        self.pos_embed_spatial = nn.Parameter(torch.zeros(1, 1, self.patch_embed.num_patches, embed_dim))
        self.pos_embed_temporal = nn.Parameter(torch.zeros(1, num_frames, 1, embed_dim))

        nn.init.trunc_normal_(self.pos_embed_spatial, std=0.02)
        nn.init.trunc_normal_(self.pos_embed_temporal, std=0.02)

        self.blocks = nn.TransformerEncoder(nn.TransformerEncoderLayer(embed_dim, num_heads, int(embed_dim*4), 0.1, 'gelu', batch_first=True), depth)
        self.norm = nn.LayerNorm(embed_dim)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None: nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0); nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

    def forward(self, x):
        B, T, C, H, W = x.shape
        x = self.patch_embed(x)
        x = x + self.pos_embed_spatial + self.pos_embed_temporal
        x = x.reshape(B, T * x.shape[2], -1)
        x = self.norm(self.blocks(x))
        return x.reshape(B, T, 14, 14, -1)

class SatMAESegmentation(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_chans=10, num_frames=12, embed_dim=768):
        super().__init__()
        self.num_frames = num_frames; self.chans = in_chans
        self.encoder = SatMAEEncoder(img_size, patch_size, in_chans, num_frames, embed_dim)

        self.up1 = nn.ConvTranspose2d(embed_dim, 256, 2, 2)
        self.conv1 = nn.Sequential(nn.Conv2d(256, 256, 3, 1, 1), nn.BatchNorm2d(256), nn.ReLU())
        self.up2 = nn.ConvTranspose2d(256, 128, 2, 2)
        self.conv2 = nn.Sequential(nn.Conv2d(128, 128, 3, 1, 1), nn.BatchNorm2d(128), nn.ReLU())
        self.up3 = nn.ConvTranspose2d(128, 64, 2, 2)
        self.conv3 = nn.Sequential(nn.Conv2d(64, 64, 3, 1, 1), nn.BatchNorm2d(64), nn.ReLU())
        self.up4 = nn.ConvTranspose2d(64, 32, 2, 2)
        self.conv4 = nn.Sequential(nn.Conv2d(32, 32, 3, 1, 1), nn.BatchNorm2d(32), nn.ReLU())
        self.final = nn.Conv2d(32, 1, 1)

    def forward(self, x):
        if x.ndim == 4: B, _, H, W = x.shape; x = x.reshape(B, self.num_frames, self.chans, H, W)
        features = self.encoder(x)
        x = features.mean(dim=1).permute(0, 3, 1, 2)
        x = self.conv1(self.up1(x))
        x = self.conv2(self.up2(x))
        x = self.conv3(self.up3(x))
        x = self.conv4(self.up4(x))
        return self.final(x)

# --- 4. TRAINER ---
class SatMAETrainer:
    def __init__(self, model, loaders, device, lr=1e-4):
        self.model = model; self.loaders = loaders; self.device = device
        self.optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
        self.scheduler = optim.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=150)
        self.swa_model = AveragedModel(model)
        self.swa_scheduler = SWALR(self.optimizer, swa_lr=lr)
        self.criterion = CompoundLoss()
        self.tta = TestTimeAugmentation(model)
        self.history = {'train_loss': [], 'val_iou': []}

    def fit(self, epochs=200):
        print(f"Starting Scratch Training ({epochs} epochs)...")
        use_swa = False; swa_start_epoch = 151; patience = 30; no_improv = 0; best_iou = 0

        for ep in range(epochs):
            self.model.train()
            t_loss = 0
            pbar = tqdm(self.loaders['train'], desc=f"Epoch {ep+1}/{epochs}", leave=False)

            for batch in pbar:
                if isinstance(batch, (list, tuple)): x, y = batch[0], batch[1]
                else: x, y = batch
                x, y = x.to(self.device), y.to(self.device)

                self.optimizer.zero_grad()
                logits = self.model(x)
                loss = self.criterion(logits, y)
                loss.backward()
                self.optimizer.step()
                t_loss += loss.item()
                pbar.set_postfix({'loss': f"{loss.item():.4f}"})

            # Validation
            eval_model = self.swa_model if use_swa else self.model
            eval_model.eval()
            v_iou_sum = 0

            with torch.no_grad():
                for batch in self.loaders['val']:
                    if isinstance(batch, (list, tuple)): x, y = batch[0], batch[1]
                    else: x, y = batch
                    x, y = x.to(self.device), y.to(self.device)

                    if use_swa: probs = torch.sigmoid(eval_model(x))
                    else: probs = self.tta.apply(x)
                    iou, _ = compute_metrics(probs, y)
                    v_iou_sum += iou

            avg_loss = t_loss / len(self.loaders['train'])
            avg_iou = v_iou_sum / len(self.loaders['val']) if len(self.loaders['val']) > 0 else 0

            self.history['train_loss'].append(avg_loss)
            self.history['val_iou'].append(avg_iou)
            print(f"Ep {ep+1} | Loss: {avg_loss:.4f} | Val IoU: {avg_iou:.4f} | SWA: {use_swa}")

            if (ep + 1) % 5 == 0: torch.save(self.model.state_dict(), f"checkpoint_ep{ep+1}.pth")

            if not use_swa:
                self.scheduler.step()
                if avg_iou > best_iou:
                    best_iou = avg_iou; no_improv = 0
                    torch.save(self.model.state_dict(), 'best_model.pth')
                    print(f"   >>> New Best IoU: {best_iou:.4f}")
                else:
                    no_improv += 1
                    if no_improv >= patience:
                        print(f"   !! No Improvement. Triggering SWA."); use_swa = True; swa_start_epoch = ep + 1
                if ep + 1 >= swa_start_epoch: use_swa = True
            else:
                self.swa_model.update_parameters(self.model); self.swa_scheduler.step()

        if use_swa:
            update_bn(self.loaders['train'], self.swa_model, device=self.device)
            torch.save(self.swa_model.module.state_dict(), 'final_swa_model.pth')
            print("Training Complete. SWA Model Saved.")
In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint # <--- CRITICAL IMPORT FOR MEMORY

# --- 1. METRICS ---
def compute_metrics(pred_probs, targets, threshold=0.5):
    pred_mask = (pred_probs > threshold).float()
    intersection = (pred_mask * targets).sum()
    union = pred_mask.sum() + targets.sum() - intersection
    iou = (intersection + 1e-6) / (union + 1e-6)
    dice = (2 * intersection + 1e-6) / (pred_mask.sum() + targets.sum() + 1e-6)
    return iou.item(), dice.item()

# --- 2. LOSS FUNCTIONS ---
class FastHausdorffLoss(nn.Module):
    def __init__(self):
        super().__init__()
    def forward(self, pred, gt):
        probs = torch.sigmoid(pred)
        p_eroded = -F.max_pool2d(-probs, kernel_size=3, stride=1, padding=1)
        t_eroded = -F.max_pool2d(-gt, kernel_size=3, stride=1, padding=1)
        return (probs - p_eroded - (gt - t_eroded)).abs().mean() * 20.0

class CompoundLoss(nn.Module):
    def forward(self, inputs, targets):
        probs = torch.sigmoid(inputs)
        inter = (probs * targets).sum()
        dice = 1 - (2. * inter / (probs.sum() + targets.sum() + 1e-6))
        hd = FastHausdorffLoss()(inputs, targets)
        return 0.7 * dice + 0.3 * hd

class TestTimeAugmentation:
    def __init__(self, model): self.model = model
    def apply(self, x):
        preds = []
        with torch.no_grad():
            preds.append(torch.sigmoid(self.model(x)))
            preds.append(torch.flip(torch.sigmoid(self.model(torch.flip(x, [-1]))), [-1]))
            preds.append(torch.flip(torch.sigmoid(self.model(torch.flip(x, [-2]))), [-2]))
            preds.append(torch.flip(torch.sigmoid(self.model(torch.flip(x, [-2,-1]))), [-2,-1]))
        return torch.stack(preds).mean(dim=0)

# --- 3. ARCHITECTURE (Memory Optimized) ---
class SatMAEPatchEmbed(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_chans=10, embed_dim=768):
        super().__init__()
        self.num_patches = (img_size // patch_size) ** 2
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        B, T, C, H, W = x.shape
        x = x.reshape(B * T, C, H, W)
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x.reshape(B, T, -1, 768)

class SatMAEEncoder(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_chans=10, num_frames=12, embed_dim=768, depth=12, num_heads=12):
        super().__init__()
        self.patch_embed = SatMAEPatchEmbed(img_size, patch_size, in_chans, embed_dim)
        self.pos_embed_spatial = nn.Parameter(torch.zeros(1, 1, self.patch_embed.num_patches, embed_dim))
        self.pos_embed_temporal = nn.Parameter(torch.zeros(1, num_frames, 1, embed_dim))

        nn.init.trunc_normal_(self.pos_embed_spatial, std=0.02)
        nn.init.trunc_normal_(self.pos_embed_temporal, std=0.02)

        # CHANGED: Use ModuleList to enable Checkpointing
        self.blocks = nn.ModuleList([
            nn.TransformerEncoderLayer(embed_dim, num_heads, int(embed_dim*4), 0.1, 'gelu', batch_first=True)
            for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(embed_dim)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None: nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0); nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

    def forward(self, x):
        B, T, C, H, W = x.shape
        x = self.patch_embed(x)
        x = x + self.pos_embed_spatial + self.pos_embed_temporal
        x = x.reshape(B, T * x.shape[2], -1)

        # --- GRADIENT CHECKPOINTING LOGIC ---
        for blk in self.blocks:
            if self.training:
                # This saves massive memory by not storing activations
                x = checkpoint(blk, x, use_reentrant=False)
            else:
                x = blk(x)

        x = self.norm(x)
        return x.reshape(B, T, 14, 14, -1)

class SatMAESegmentation(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_chans=10, num_frames=12, embed_dim=768):
        super().__init__()
        self.num_frames = num_frames; self.chans = in_chans
        self.encoder = SatMAEEncoder(img_size, patch_size, in_chans, num_frames, embed_dim)

        self.up1 = nn.ConvTranspose2d(embed_dim, 256, 2, 2)
        self.conv1 = nn.Sequential(nn.Conv2d(256, 256, 3, 1, 1), nn.BatchNorm2d(256), nn.ReLU())
        self.up2 = nn.ConvTranspose2d(256, 128, 2, 2)
        self.conv2 = nn.Sequential(nn.Conv2d(128, 128, 3, 1, 1), nn.BatchNorm2d(128), nn.ReLU())
        self.up3 = nn.ConvTranspose2d(128, 64, 2, 2)
        self.conv3 = nn.Sequential(nn.Conv2d(64, 64, 3, 1, 1), nn.BatchNorm2d(64), nn.ReLU())
        self.up4 = nn.ConvTranspose2d(64, 32, 2, 2)
        self.conv4 = nn.Sequential(nn.Conv2d(32, 32, 3, 1, 1), nn.BatchNorm2d(32), nn.ReLU())
        self.final = nn.Conv2d(32, 1, 1)

    def forward(self, x):
        if x.ndim == 4: B, _, H, W = x.shape; x = x.reshape(B, self.num_frames, self.chans, H, W)
        features = self.encoder(x)
        x = features.mean(dim=1).permute(0, 3, 1, 2)
        x = self.conv1(self.up1(x))
        x = self.conv2(self.up2(x))
        x = self.conv3(self.up3(x))
        x = self.conv4(self.up4(x))
        return self.final(x)
In [5]:
import torch
import torch.optim as optim
from torch.optim.swa_utils import AveragedModel, SWALR, update_bn
import os
import csv
import pandas as pd
import matplotlib.pyplot as plt

# --- CONFIGURATION ---
SAVE_DIR = '/content/drive/MyDrive/SatMAE_Scratch_Results_12Frames/'
CHECKPOINT_PATH = os.path.join(SAVE_DIR, "checkpoint_latest.pth")
BEST_MODEL_PATH = os.path.join(SAVE_DIR, "best_model.pth")
LOG_PATH = os.path.join(SAVE_DIR, "training_log.csv")

BATCH_SIZE = 16
MAX_EPOCHS = 5000
PATIENCE_TRIGGER = 100
SWA_DURATION = 50

# Ensure Directory
if not os.path.exists(SAVE_DIR): os.makedirs(SAVE_DIR)

# --- SETUP ---
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

# Model (12 Frames, 12 Channels)
model = SatMAESegmentation(img_size=224, patch_size=16, in_chans=12, num_frames=12).to(device)
criterion = CompoundLoss()
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=50, T_mult=2)

# SWA Setup
swa_model = AveragedModel(model)
swa_scheduler = SWALR(optimizer, swa_lr=5e-5)


loaders = {
    'train': DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=2),
    'val': DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)
}

# --- RESUME LOGIC ---
start_epoch = 0
best_iou = 0.0
patience_counter = 0
swa_active = False
swa_epoch_counter = 0
history = {'epoch': [], 'train_loss': [], 'val_loss': [], 'val_iou': []}

if os.path.exists(CHECKPOINT_PATH):
    print(" Found checkpoint. Resuming...")
    ckpt = torch.load(CHECKPOINT_PATH, map_location=device)
    model.load_state_dict(ckpt['model_state_dict'])
    optimizer.load_state_dict(ckpt['optimizer_state_dict'])
    scheduler.load_state_dict(ckpt['scheduler_state_dict'])
    start_epoch = ckpt['epoch'] + 1
    best_iou = ckpt['best_iou']
    patience_counter = ckpt['patience_counter']
    swa_active = ckpt['swa_active']
    swa_epoch_counter = ckpt['swa_epoch_counter']
    if swa_active:
        swa_model.load_state_dict(ckpt['swa_state_dict'])
        swa_scheduler.load_state_dict(ckpt['swa_scheduler_state_dict'])

    # Load history from CSV
    if os.path.exists(LOG_PATH):
        df = pd.read_csv(LOG_PATH)
        history['epoch'] = df['epoch'].tolist()
        history['train_loss'] = df['train_loss'].tolist()
        history['val_loss'] = df['val_loss'].tolist()
        history['val_iou'] = df['val_iou'].tolist()
    print(f"Resumed at Epoch {start_epoch}. Best IoU: {best_iou:.4f}")

# --- TRAINING LOOP ---
print(f" Starting Training. Max: {MAX_EPOCHS} Eps. Patience: {PATIENCE_TRIGGER}")

try:
    for ep in range(start_epoch, MAX_EPOCHS):
        model.train()
        train_loss = 0

        # Training Step
        for x, y in tqdm(loaders['train'], desc=f"Ep {ep+1}", leave=False):
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            preds = model(x)
            loss = criterion(preds, y)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        # Validation Step
        eval_model = swa_model if swa_active else model
        eval_model.eval()
        val_loss = 0
        val_iou_sum = 0

        with torch.no_grad():
            for x, y in loaders['val']:
                x, y = x.to(device), y.to(device)

                # Forward
                if swa_active: preds = eval_model(x)
                else: preds = model(x) # No TTA for speed during training loop

                val_loss += criterion(preds, y).item()

                # Metrics
                probs = torch.sigmoid(preds)
                iou, _ = compute_metrics(probs, y)
                val_iou_sum += iou

        # Stats
        avg_t = train_loss / len(loaders['train'])
        avg_v = val_loss / len(loaders['val'])
        avg_iou = val_iou_sum / len(loaders['val'])

        # Log to History
        history['epoch'].append(ep+1)
        history['train_loss'].append(avg_t)
        history['val_loss'].append(avg_v)
        history['val_iou'].append(avg_iou)

        # Append to CSV immediately (Crash Safety)
        with open(LOG_PATH, 'a', newline='') as f:
            writer = csv.writer(f)
            if ep == 0 and not os.path.exists(LOG_PATH):
                writer.writerow(['epoch', 'train_loss', 'val_loss', 'val_iou'])
            writer.writerow([ep+1, avg_t, avg_v, avg_iou])

        # Status Message
        status = ""

        # --- LOGIC BRANCHING ---
        if swa_active:
            # SWA PHASE
            swa_model.update_parameters(model)
            swa_scheduler.step()
            swa_epoch_counter += 1
            status = f"SWA Mode ({swa_epoch_counter}/{SWA_DURATION})"

            # Stop Condition
            if swa_epoch_counter >= SWA_DURATION:
                print(f"Epoch {ep+1} | T: {avg_t:.4f} | V: {avg_v:.4f} | IoU: {avg_iou:.4f} | {status}")
                print(" SWA Complete. Saving Final Model.")
                update_bn(loaders['train'], swa_model, device=device)
                torch.save(swa_model.module.state_dict(), os.path.join(SAVE_DIR, "final_swa_model.pth"))
                break # EXIT LOOP
        else:
            # NORMAL PHASE
            scheduler.step()

            if avg_iou > best_iou:
                best_iou = avg_iou
                patience_counter = 0
                torch.save(model.state_dict(), BEST_MODEL_PATH)
                status = f" Best IoU!"
            else:
                patience_counter += 1
                status = f"No Improv ({patience_counter}/{PATIENCE_TRIGGER})"

            # Trigger SWA?
            if patience_counter >= PATIENCE_TRIGGER:
                print(f" Patience Limit Reached. Triggering SWA for {SWA_DURATION} epochs...")
                swa_active = True
                swa_epoch_counter = 0
                # Load best weights before starting SWA to ensure stability
                model.load_state_dict(torch.load(BEST_MODEL_PATH))

        print(f"Ep {ep+1} | T: {avg_t:.4f} | V: {avg_v:.4f} | IoU: {avg_iou:.4f} | {status}")


        torch.save({
            'epoch': ep,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'swa_state_dict': swa_model.state_dict() if swa_active else None,
            'swa_scheduler_state_dict': swa_scheduler.state_dict() if swa_active else None,
            'best_iou': best_iou,
            'patience_counter': patience_counter,
            'swa_active': swa_active,
            'swa_epoch_counter': swa_epoch_counter
        }, CHECKPOINT_PATH)

except KeyboardInterrupt:
    print("Training Interrupted. Checkpoint Saved.")

# Plot Results
plt.figure(figsize=(10, 5))
plt.plot(history['train_loss'], label='Train Loss')
plt.plot(history['val_loss'], label='Val Loss')
plt.legend()
plt.title("Training Curves")
plt.savefig(os.path.join(SAVE_DIR, "loss_curve.png"))
plt.show()
Device: cuda
---------------------------------------------------------------------------
NameError                                 Traceback (most recent call last)
/tmp/ipython-input-2473747799.py in <cell line: 0>()
     37 
     38 loaders = {
---> 39     'train': DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=2),
     40     'val': DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)
     41 }

NameError: name 'train_ds' is not defined
In [6]:
import torch
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.optim.swa_utils import AveragedModel, SWALR, update_bn
import os
import csv
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

# --- CONFIGURATION ---
SAVE_DIR = '/content/drive/MyDrive/SatMAE_Scratch_Results_12Frames/'
CHECKPOINT_PATH = os.path.join(SAVE_DIR, "checkpoint_latest.pth")
BEST_MODEL_PATH = os.path.join(SAVE_DIR, "best_model.pth")
LOG_PATH = os.path.join(SAVE_DIR, "training_log.csv")

BATCH_SIZE = 16
MAX_EPOCHS = 5000
PATIENCE_TRIGGER = 100
SWA_DURATION = 50

# Ensure Directory
if not os.path.exists(SAVE_DIR): os.makedirs(SAVE_DIR)

# --- 1. DATASET DEFINITION (INCLUDED HERE TO PREVENT ERRORS) ---
class SatMAEDataset(Dataset):
    def __init__(self, x_path, y_path):
        # Load in mmap mode (Zero RAM usage initially)
        self.x_data = np.load(x_path, mmap_mode='r')
        self.y_data = np.load(y_path, mmap_mode='r')

    def __len__(self):
        return len(self.x_data)

    def __getitem__(self, idx):
        # Read into RAM
        x_np = self.x_data[idx]
        y_np = self.y_data[idx]

        # Fast Numpy Scaling
        x_float = x_np.astype(np.float32)
        y_float = y_np.astype(np.float32)

        # Zero out '255' (NoData)
        x_float[x_np == 255] = 0.0
        y_float[y_np == 255] = 0.0

        # Scale 0-250 -> 0-1
        x_float /= 250.0

        return torch.from_numpy(x_float), torch.from_numpy(y_float)

# --- 2. SETUP LOADERS ---
try:
    print("Initializing Datasets...")
    train_ds = SatMAEDataset(os.path.join(SAVE_DIR, 'train_x.npy'), os.path.join(SAVE_DIR, 'train_y.npy'))
    val_ds = SatMAEDataset(os.path.join(SAVE_DIR, 'val_x.npy'), os.path.join(SAVE_DIR, 'val_y.npy'))

    loaders = {
        'train': DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=2),
        'val': DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)
    }
    print(" Loaders Ready.")
except Exception as e:
    print(f" Error loading data: {e}")
    print("Did you run Cell 1 to generate the .npy files?")
    raise e

# --- 3. MODEL SETUP ---
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

# Dependency Check: Ensure Cell 2 was run
try:
    model = SatMAESegmentation(img_size=224, patch_size=16, in_chans=10, num_frames=12).to(device)
    criterion = CompoundLoss()
except NameError:
    raise NameError(" Error: 'SatMAESegmentation' or 'CompoundLoss' not defined. Please RUN CELL 2 (Architecture) first.")

optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=50, T_mult=2)

# SWA Setup
swa_model = AveragedModel(model)
swa_scheduler = SWALR(optimizer, swa_lr=5e-5)

# --- 4. RESUME LOGIC ---
start_epoch = 0
best_iou = 0.0
patience_counter = 0
swa_active = False
swa_epoch_counter = 0
history = {'epoch': [], 'train_loss': [], 'val_loss': [], 'val_iou': []}

if os.path.exists(CHECKPOINT_PATH):
    print(" Found checkpoint. Resuming...")
    ckpt = torch.load(CHECKPOINT_PATH, map_location=device)
    model.load_state_dict(ckpt['model_state_dict'])
    optimizer.load_state_dict(ckpt['optimizer_state_dict'])
    scheduler.load_state_dict(ckpt['scheduler_state_dict'])
    start_epoch = ckpt['epoch'] + 1
    best_iou = ckpt['best_iou']
    patience_counter = ckpt['patience_counter']
    swa_active = ckpt['swa_active']
    swa_epoch_counter = ckpt['swa_epoch_counter']
    if swa_active:
        swa_model.load_state_dict(ckpt['swa_state_dict'])
        swa_scheduler.load_state_dict(ckpt['swa_scheduler_state_dict'])

    if os.path.exists(LOG_PATH):
        df = pd.read_csv(LOG_PATH)
        history['epoch'] = df['epoch'].tolist()
        history['train_loss'] = df['train_loss'].tolist()
        history['val_loss'] = df['val_loss'].tolist()
        history['val_iou'] = df['val_iou'].tolist()
    print(f"Resumed at Epoch {start_epoch}. Best IoU: {best_iou:.4f}")

# --- 5. TRAINING LOOP ---
print(f" Starting Training. Max: {MAX_EPOCHS} Eps. Patience: {PATIENCE_TRIGGER}")

try:
    for ep in range(start_epoch, MAX_EPOCHS):
        model.train()
        train_loss = 0

        # Training Step
        for x, y in tqdm(loaders['train'], desc=f"Ep {ep+1}", leave=False):
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            preds = model(x)
            loss = criterion(preds, y)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        # Validation Step
        eval_model = swa_model if swa_active else model
        eval_model.eval()
        val_loss = 0
        val_iou_sum = 0

        with torch.no_grad():
            for x, y in loaders['val']:
                x, y = x.to(device), y.to(device)

                if swa_active: preds = eval_model(x)
                else: preds = model(x)

                val_loss += criterion(preds, y).item()

                probs = torch.sigmoid(preds)
                # Helper metric calc
                pred_mask = (probs > 0.5).float()
                inter = (pred_mask * y).sum()
                union = pred_mask.sum() + y.sum() - inter
                iou = (inter + 1e-6) / (union + 1e-6)
                val_iou_sum += iou.item()

        # Stats
        avg_t = train_loss / len(loaders['train'])
        avg_v = val_loss / len(loaders['val'])
        avg_iou = val_iou_sum / len(loaders['val'])

        # Log to History
        history['epoch'].append(ep+1)
        history['train_loss'].append(avg_t)
        history['val_loss'].append(avg_v)
        history['val_iou'].append(avg_iou)

        # Append to CSV
        with open(LOG_PATH, 'a', newline='') as f:
            writer = csv.writer(f)
            if ep == 0 and not os.path.exists(LOG_PATH):
                writer.writerow(['epoch', 'train_loss', 'val_loss', 'val_iou'])
            writer.writerow([ep+1, avg_t, avg_v, avg_iou])

        # Status & Logic
        status = ""

        if swa_active:
            swa_model.update_parameters(model)
            swa_scheduler.step()
            swa_epoch_counter += 1
            status = f"SWA Mode ({swa_epoch_counter}/{SWA_DURATION})"

            if swa_epoch_counter >= SWA_DURATION:
                print(f"Ep {ep+1} | T: {avg_t:.4f} | V: {avg_v:.4f} | IoU: {avg_iou:.4f} | {status}")
                print("SWA Complete. Saving Final Model.")
                update_bn(loaders['train'], swa_model, device=device)
                torch.save(swa_model.module.state_dict(), os.path.join(SAVE_DIR, "final_swa_model.pth"))
                break
        else:
            scheduler.step()
            if avg_iou > best_iou:
                best_iou = avg_iou
                patience_counter = 0
                torch.save(model.state_dict(), BEST_MODEL_PATH)
                status = f" Best IoU!"
            else:
                patience_counter += 1
                status = f"No Improv ({patience_counter}/{PATIENCE_TRIGGER})"

            if patience_counter >= PATIENCE_TRIGGER:
                print(f" Patience Limit Reached. Triggering SWA for {SWA_DURATION} epochs...")
                swa_active = True
                swa_epoch_counter = 0
                model.load_state_dict(torch.load(BEST_MODEL_PATH))

        print(f"Ep {ep+1} | T: {avg_t:.4f} | V: {avg_v:.4f} | IoU: {avg_iou:.4f} | {status}")

        # Save Checkpoint
        torch.save({
            'epoch': ep,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'swa_state_dict': swa_model.state_dict() if swa_active else None,
            'swa_scheduler_state_dict': swa_scheduler.state_dict() if swa_active else None,
            'best_iou': best_iou,
            'patience_counter': patience_counter,
            'swa_active': swa_active,
            'swa_epoch_counter': swa_epoch_counter
        }, CHECKPOINT_PATH)

except KeyboardInterrupt:
    print("Training Interrupted. Checkpoint Saved.")

# Plot Results
if len(history['train_loss']) > 0:
    plt.figure(figsize=(10, 5))
    plt.plot(history['train_loss'], label='Train Loss')
    plt.plot(history['val_loss'], label='Val Loss')
    plt.legend()
    plt.title("Training Curves")
    plt.savefig(os.path.join(SAVE_DIR, "loss_curve.png"))
    plt.show()
Initializing Datasets...
 Loaders Ready.
Device: cuda
 Starting Training. Max: 5000 Eps. Patience: 100
Ep 1:   0%|          | 0/1 [00:00<?, ?it/s]
---------------------------------------------------------------------------
OutOfMemoryError                          Traceback (most recent call last)
/tmp/ipython-input-581224262.py in <cell line: 0>()
    129             x, y = x.to(device), y.to(device)
    130             optimizer.zero_grad()
--> 131             preds = model(x)
    132             loss = criterion(preds, y)
    133             loss.backward()

/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1773             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1774         else:
-> 1775             return self._call_impl(*args, **kwargs)
   1776 
   1777     # torchrec tests the code consistency with the following code

/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1784                 or _global_backward_pre_hooks or _global_backward_hooks
   1785                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1786             return forward_call(*args, **kwargs)
   1787 
   1788         result = None

/tmp/ipython-input-3544700008.py in forward(self, x)
    124     def forward(self, x):
    125         if x.ndim == 4: B, _, H, W = x.shape; x = x.reshape(B, self.num_frames, self.chans, H, W)
--> 126         features = self.encoder(x)
    127         x = features.mean(dim=1).permute(0, 3, 1, 2)
    128         x = self.conv1(self.up1(x))

/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1773             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1774         else:
-> 1775             return self._call_impl(*args, **kwargs)
   1776 
   1777     # torchrec tests the code consistency with the following code

/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1784                 or _global_backward_pre_hooks or _global_backward_hooks
   1785                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1786             return forward_call(*args, **kwargs)
   1787 
   1788         result = None

/tmp/ipython-input-3544700008.py in forward(self, x)
    103         x = x + self.pos_embed_spatial + self.pos_embed_temporal
    104         x = x.reshape(B, T * x.shape[2], -1)
--> 105         x = self.norm(self.blocks(x))
    106         return x.reshape(B, T, 14, 14, -1)
    107 

/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1773             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1774         else:
-> 1775             return self._call_impl(*args, **kwargs)
   1776 
   1777     # torchrec tests the code consistency with the following code

/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1784                 or _global_backward_pre_hooks or _global_backward_hooks
   1785                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1786             return forward_call(*args, **kwargs)
   1787 
   1788         result = None

/usr/local/lib/python3.12/dist-packages/torch/nn/modules/transformer.py in forward(self, src, mask, src_key_padding_mask, is_causal)
    522 
    523         for mod in self.layers:
--> 524             output = mod(
    525                 output,
    526                 src_mask=mask,

/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1773             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1774         else:
-> 1775             return self._call_impl(*args, **kwargs)
   1776 
   1777     # torchrec tests the code consistency with the following code

/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1784                 or _global_backward_pre_hooks or _global_backward_hooks
   1785                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1786             return forward_call(*args, **kwargs)
   1787 
   1788         result = None

/usr/local/lib/python3.12/dist-packages/torch/nn/modules/transformer.py in forward(self, src, src_mask, src_key_padding_mask, is_causal)
    935                 + self._sa_block(x, src_mask, src_key_padding_mask, is_causal=is_causal)
    936             )
--> 937             x = self.norm2(x + self._ff_block(x))
    938 
    939         return x

/usr/local/lib/python3.12/dist-packages/torch/nn/modules/transformer.py in _ff_block(self, x)
    960     # feed forward block
    961     def _ff_block(self, x: Tensor) -> Tensor:
--> 962         x = self.linear2(self.dropout(self.activation(self.linear1(x))))
    963         return self.dropout2(x)
    964 

OutOfMemoryError: CUDA out of memory. Tried to allocate 250.00 MiB. GPU 0 has a total capacity of 14.74 GiB of which 180.12 MiB is free. Process 4969 has 14.56 GiB memory in use. Of the allocated memory 14.08 GiB is allocated by PyTorch, and 367.47 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
In [7]:
import torch
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.optim.swa_utils import AveragedModel, SWALR, update_bn
from torch.cuda.amp import GradScaler, autocast  # <--- NEW: For Mixed Precision
import os
import csv
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
import gc

# --- CONFIGURATION ---
SAVE_DIR = '/content/drive/MyDrive/SatMAE_Scratch_Results_12Frames/'
CHECKPOINT_PATH = os.path.join(SAVE_DIR, "checkpoint_latest.pth")
BEST_MODEL_PATH = os.path.join(SAVE_DIR, "best_model.pth")
LOG_PATH = os.path.join(SAVE_DIR, "training_log.csv")

# --- MEMORY OPTIMIZATION SETTINGS ---
PHYSICAL_BATCH_SIZE = 4   # Small enough to fit in GPU
ACCUM_STEPS = 4           # 4 * 4 = 16 (Effective Batch Size)
MAX_EPOCHS = 5000
PATIENCE_TRIGGER = 100
SWA_DURATION = 50

# Ensure Directory
if not os.path.exists(SAVE_DIR): os.makedirs(SAVE_DIR)

# --- 1. DATASET DEFINITION ---
class SatMAEDataset(Dataset):
    def __init__(self, x_path, y_path):
        self.x_data = np.load(x_path, mmap_mode='r')
        self.y_data = np.load(y_path, mmap_mode='r')
    def __len__(self): return len(self.x_data)
    def __getitem__(self, idx):
        x_np = self.x_data[idx]
        y_np = self.y_data[idx]
        x_float = x_np.astype(np.float32)
        y_float = y_np.astype(np.float32)
        x_float[x_np == 255] = 0.0
        y_float[y_np == 255] = 0.0
        x_float /= 250.0
        return torch.from_numpy(x_float), torch.from_numpy(y_float)

# --- 2. SETUP LOADERS ---
try:
    print("Initializing Datasets...")
    train_ds = SatMAEDataset(os.path.join(SAVE_DIR, 'train_x.npy'), os.path.join(SAVE_DIR, 'train_y.npy'))
    val_ds = SatMAEDataset(os.path.join(SAVE_DIR, 'val_x.npy'), os.path.join(SAVE_DIR, 'val_y.npy'))

    loaders = {
        'train': DataLoader(train_ds, batch_size=PHYSICAL_BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True),
        'val': DataLoader(val_ds, batch_size=PHYSICAL_BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)
    }
    print(" Loaders Ready.")
except Exception as e:
    print(f" Error loading data: {e}")
    raise e

# --- 3. MODEL SETUP ---
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

# Clear Cache before starting
torch.cuda.empty_cache()
gc.collect()

try:
    model = SatMAESegmentation(img_size=224, patch_size=16, in_chans=10, num_frames=12).to(device)
    criterion = CompoundLoss()
except NameError:
    raise NameError(" Error: 'SatMAESegmentation' not defined. Run Cell 2 first.")

optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=50, T_mult=2)
scaler = GradScaler() # <--- NEW: Handles FP16 gradients

# SWA Setup
swa_model = AveragedModel(model)
swa_scheduler = SWALR(optimizer, swa_lr=5e-5)

# --- 4. RESUME LOGIC ---
start_epoch = 0
best_iou = 0.0
patience_counter = 0
swa_active = False
swa_epoch_counter = 0
history = {'epoch': [], 'train_loss': [], 'val_loss': [], 'val_iou': []}

if os.path.exists(CHECKPOINT_PATH):
    print(" Found checkpoint. Resuming...")
    ckpt = torch.load(CHECKPOINT_PATH, map_location=device)
    model.load_state_dict(ckpt['model_state_dict'])
    optimizer.load_state_dict(ckpt['optimizer_state_dict'])
    scheduler.load_state_dict(ckpt['scheduler_state_dict'])
    start_epoch = ckpt['epoch'] + 1
    best_iou = ckpt['best_iou']
    patience_counter = ckpt['patience_counter']
    swa_active = ckpt['swa_active']
    swa_epoch_counter = ckpt['swa_epoch_counter']
    if swa_active:
        swa_model.load_state_dict(ckpt['swa_state_dict'])
        swa_scheduler.load_state_dict(ckpt['swa_scheduler_state_dict'])

    if os.path.exists(LOG_PATH):
        df = pd.read_csv(LOG_PATH)
        history['epoch'] = df['epoch'].tolist()
        history['train_loss'] = df['train_loss'].tolist()
        history['val_loss'] = df['val_loss'].tolist()
        history['val_iou'] = df['val_iou'].tolist()
    print(f"Resumed at Epoch {start_epoch}. Best IoU: {best_iou:.4f}")

# --- 5. TRAINING LOOP (OPTIMIZED) ---
print(f"Starting Training. Max: {MAX_EPOCHS} Eps. Batch: {PHYSICAL_BATCH_SIZE} (Accum to 16)")

try:
    for ep in range(start_epoch, MAX_EPOCHS):
        model.train()
        train_loss = 0
        optimizer.zero_grad() # Initialize gradients

        # Training Step
        pbar = tqdm(loaders['train'], desc=f"Ep {ep+1}", leave=False)
        for i, (x, y) in enumerate(pbar):
            x, y = x.to(device), y.to(device)

            # Mixed Precision Forward
            with autocast(enabled=True):
                preds = model(x)
                loss = criterion(preds, y)
                loss = loss / ACCUM_STEPS # Normalize loss

            # Backward
            scaler.scale(loss).backward()

            # Update weights every ACCUM_STEPS
            if (i + 1) % ACCUM_STEPS == 0:
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

            train_loss += loss.item() * ACCUM_STEPS # Scale back up for logging

        # Validation Step
        eval_model = swa_model if swa_active else model
        eval_model.eval()
        val_loss = 0
        val_iou_sum = 0

        with torch.no_grad():
            for x, y in loaders['val']:
                x, y = x.to(device), y.to(device)

                with autocast(enabled=True):
                    if swa_active: preds = eval_model(x)
                    else: preds = model(x)
                    val_loss += criterion(preds, y).item()

                probs = torch.sigmoid(preds)
                pred_mask = (probs > 0.5).float()
                inter = (pred_mask * y).sum()
                union = pred_mask.sum() + y.sum() - inter
                iou = (inter + 1e-6) / (union + 1e-6)
                val_iou_sum += iou.item()

        # Stats
        avg_t = train_loss / len(loaders['train'])
        avg_v = val_loss / len(loaders['val'])
        avg_iou = val_iou_sum / len(loaders['val'])

        # Append to CSV
        with open(LOG_PATH, 'a', newline='') as f:
            writer = csv.writer(f)
            if ep == 0 and not os.path.exists(LOG_PATH):
                writer.writerow(['epoch', 'train_loss', 'val_loss', 'val_iou'])
            writer.writerow([ep+1, avg_t, avg_v, avg_iou])

        # Status & Logic
        status = ""
        if swa_active:
            swa_model.update_parameters(model)
            swa_scheduler.step()
            swa_epoch_counter += 1
            status = f"SWA Mode ({swa_epoch_counter}/{SWA_DURATION})"
            if swa_epoch_counter >= SWA_DURATION:
                print(f"Ep {ep+1} | T: {avg_t:.4f} | V: {avg_v:.4f} | IoU: {avg_iou:.4f} | {status}")
                print(" SWA Complete. Saving Final Model.")
                update_bn(loaders['train'], swa_model, device=device)
                torch.save(swa_model.module.state_dict(), os.path.join(SAVE_DIR, "final_swa_model.pth"))
                break
        else:
            scheduler.step()
            if avg_iou > best_iou:
                best_iou = avg_iou
                patience_counter = 0
                torch.save(model.state_dict(), BEST_MODEL_PATH)
                status = f" Best IoU!"
            else:
                patience_counter += 1
                status = f"No Improv ({patience_counter}/{PATIENCE_TRIGGER})"
            if patience_counter >= PATIENCE_TRIGGER:
                print(f" Patience Limit Reached. Triggering SWA for {SWA_DURATION} epochs...")
                swa_active = True
                swa_epoch_counter = 0
                model.load_state_dict(torch.load(BEST_MODEL_PATH))

        print(f"Ep {ep+1} | T: {avg_t:.4f} | V: {avg_v:.4f} | IoU: {avg_iou:.4f} | {status}")

        # Save Checkpoint
        torch.save({
            'epoch': ep, 'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'swa_state_dict': swa_model.state_dict() if swa_active else None,
            'swa_scheduler_state_dict': swa_scheduler.state_dict() if swa_active else None,
            'best_iou': best_iou, 'patience_counter': patience_counter,
            'swa_active': swa_active, 'swa_epoch_counter': swa_epoch_counter
        }, CHECKPOINT_PATH)

except KeyboardInterrupt:
    print("Training Interrupted. Checkpoint Saved.")

# Plot Results
if os.path.exists(LOG_PATH):
    df = pd.read_csv(LOG_PATH)
    if len(df) > 0:
        plt.figure(figsize=(10, 5))
        plt.plot(df['train_loss'], label='Train Loss')
        plt.plot(df['val_loss'], label='Val Loss')
        plt.legend()
        plt.title("Training Curves")
        plt.savefig(os.path.join(SAVE_DIR, "loss_curve.png"))
        plt.show()
Exception ignored in: <function _ConnectionBase.__del__ at 0x78fc142d4540>
Traceback (most recent call last):
  File "/usr/lib/python3.12/multiprocessing/connection.py", line 133, in __del__
  File "/usr/lib/python3.12/multiprocessing/connection.py", line 377, in _close
OSError: [Errno 9] Bad file descriptor
Initializing Datasets...
 Loaders Ready.
Device: cuda
/tmp/ipython-input-2721185616.py:77: FutureWarning: `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.
  scaler = GradScaler() # <--- NEW: Handles FP16 gradients
---------------------------------------------------------------------------
OutOfMemoryError                          Traceback (most recent call last)
/tmp/ipython-input-2721185616.py in <cell line: 0>()
     78 
     79 # SWA Setup
---> 80 swa_model = AveragedModel(model)
     81 swa_scheduler = SWALR(optimizer, swa_lr=5e-5)
     82 

/usr/local/lib/python3.12/dist-packages/torch/optim/swa_utils.py in __init__(self, model, device, avg_fn, multi_avg_fn, use_buffers)
    232             "Only one of avg_fn and multi_avg_fn should be provided"
    233         )
--> 234         self.module = deepcopy(model)
    235         if device is not None:
    236             self.module = self.module.to(device)

/usr/lib/python3.12/copy.py in deepcopy(x, memo, _nil)
    160                     y = x
    161                 else:
--> 162                     y = _reconstruct(x, memo, *rv)
    163 
    164     # If is its own copy, don't memoize.

/usr/lib/python3.12/copy.py in _reconstruct(x, memo, func, args, state, listiter, dictiter, deepcopy)
    257     if state is not None:
    258         if deep:
--> 259             state = deepcopy(state, memo)
    260         if hasattr(y, '__setstate__'):
    261             y.__setstate__(state)

/usr/lib/python3.12/copy.py in deepcopy(x, memo, _nil)
    134     copier = _deepcopy_dispatch.get(cls)
    135     if copier is not None:
--> 136         y = copier(x, memo)
    137     else:
    138         if issubclass(cls, type):

/usr/lib/python3.12/copy.py in _deepcopy_dict(x, memo, deepcopy)
    219     memo[id(x)] = y
    220     for key, value in x.items():
--> 221         y[deepcopy(key, memo)] = deepcopy(value, memo)
    222     return y
    223 d[dict] = _deepcopy_dict

/usr/lib/python3.12/copy.py in deepcopy(x, memo, _nil)
    134     copier = _deepcopy_dispatch.get(cls)
    135     if copier is not None:
--> 136         y = copier(x, memo)
    137     else:
    138         if issubclass(cls, type):

/usr/lib/python3.12/copy.py in _deepcopy_dict(x, memo, deepcopy)
    219     memo[id(x)] = y
    220     for key, value in x.items():
--> 221         y[deepcopy(key, memo)] = deepcopy(value, memo)
    222     return y
    223 d[dict] = _deepcopy_dict

/usr/lib/python3.12/copy.py in deepcopy(x, memo, _nil)
    160                     y = x
    161                 else:
--> 162                     y = _reconstruct(x, memo, *rv)
    163 
    164     # If is its own copy, don't memoize.

/usr/lib/python3.12/copy.py in _reconstruct(x, memo, func, args, state, listiter, dictiter, deepcopy)
    257     if state is not None:
    258         if deep:
--> 259             state = deepcopy(state, memo)
    260         if hasattr(y, '__setstate__'):
    261             y.__setstate__(state)

/usr/lib/python3.12/copy.py in deepcopy(x, memo, _nil)
    134     copier = _deepcopy_dispatch.get(cls)
    135     if copier is not None:
--> 136         y = copier(x, memo)
    137     else:
    138         if issubclass(cls, type):

/usr/lib/python3.12/copy.py in _deepcopy_dict(x, memo, deepcopy)
    219     memo[id(x)] = y
    220     for key, value in x.items():
--> 221         y[deepcopy(key, memo)] = deepcopy(value, memo)
    222     return y
    223 d[dict] = _deepcopy_dict

/usr/lib/python3.12/copy.py in deepcopy(x, memo, _nil)
    134     copier = _deepcopy_dispatch.get(cls)
    135     if copier is not None:
--> 136         y = copier(x, memo)
    137     else:
    138         if issubclass(cls, type):

/usr/lib/python3.12/copy.py in _deepcopy_dict(x, memo, deepcopy)
    219     memo[id(x)] = y
    220     for key, value in x.items():
--> 221         y[deepcopy(key, memo)] = deepcopy(value, memo)
    222     return y
    223 d[dict] = _deepcopy_dict

/usr/lib/python3.12/copy.py in deepcopy(x, memo, _nil)
    160                     y = x
    161                 else:
--> 162                     y = _reconstruct(x, memo, *rv)
    163 
    164     # If is its own copy, don't memoize.

/usr/lib/python3.12/copy.py in _reconstruct(x, memo, func, args, state, listiter, dictiter, deepcopy)
    257     if state is not None:
    258         if deep:
--> 259             state = deepcopy(state, memo)
    260         if hasattr(y, '__setstate__'):
    261             y.__setstate__(state)

/usr/lib/python3.12/copy.py in deepcopy(x, memo, _nil)
    134     copier = _deepcopy_dispatch.get(cls)
    135     if copier is not None:
--> 136         y = copier(x, memo)
    137     else:
    138         if issubclass(cls, type):

/usr/lib/python3.12/copy.py in _deepcopy_dict(x, memo, deepcopy)
    219     memo[id(x)] = y
    220     for key, value in x.items():
--> 221         y[deepcopy(key, memo)] = deepcopy(value, memo)
    222     return y
    223 d[dict] = _deepcopy_dict

/usr/lib/python3.12/copy.py in deepcopy(x, memo, _nil)
    134     copier = _deepcopy_dispatch.get(cls)
    135     if copier is not None:
--> 136         y = copier(x, memo)
    137     else:
    138         if issubclass(cls, type):

/usr/lib/python3.12/copy.py in _deepcopy_dict(x, memo, deepcopy)
    219     memo[id(x)] = y
    220     for key, value in x.items():
--> 221         y[deepcopy(key, memo)] = deepcopy(value, memo)
    222     return y
    223 d[dict] = _deepcopy_dict

/usr/lib/python3.12/copy.py in deepcopy(x, memo, _nil)
    160                     y = x
    161                 else:
--> 162                     y = _reconstruct(x, memo, *rv)
    163 
    164     # If is its own copy, don't memoize.

/usr/lib/python3.12/copy.py in _reconstruct(x, memo, func, args, state, listiter, dictiter, deepcopy)
    257     if state is not None:
    258         if deep:
--> 259             state = deepcopy(state, memo)
    260         if hasattr(y, '__setstate__'):
    261             y.__setstate__(state)

/usr/lib/python3.12/copy.py in deepcopy(x, memo, _nil)
    134     copier = _deepcopy_dispatch.get(cls)
    135     if copier is not None:
--> 136         y = copier(x, memo)
    137     else:
    138         if issubclass(cls, type):

/usr/lib/python3.12/copy.py in _deepcopy_dict(x, memo, deepcopy)
    219     memo[id(x)] = y
    220     for key, value in x.items():
--> 221         y[deepcopy(key, memo)] = deepcopy(value, memo)
    222     return y
    223 d[dict] = _deepcopy_dict

/usr/lib/python3.12/copy.py in deepcopy(x, memo, _nil)
    134     copier = _deepcopy_dispatch.get(cls)
    135     if copier is not None:
--> 136         y = copier(x, memo)
    137     else:
    138         if issubclass(cls, type):

/usr/lib/python3.12/copy.py in _deepcopy_dict(x, memo, deepcopy)
    219     memo[id(x)] = y
    220     for key, value in x.items():
--> 221         y[deepcopy(key, memo)] = deepcopy(value, memo)
    222     return y
    223 d[dict] = _deepcopy_dict

/usr/lib/python3.12/copy.py in deepcopy(x, memo, _nil)
    160                     y = x
    161                 else:
--> 162                     y = _reconstruct(x, memo, *rv)
    163 
    164     # If is its own copy, don't memoize.

/usr/lib/python3.12/copy.py in _reconstruct(x, memo, func, args, state, listiter, dictiter, deepcopy)
    257     if state is not None:
    258         if deep:
--> 259             state = deepcopy(state, memo)
    260         if hasattr(y, '__setstate__'):
    261             y.__setstate__(state)

/usr/lib/python3.12/copy.py in deepcopy(x, memo, _nil)
    134     copier = _deepcopy_dispatch.get(cls)
    135     if copier is not None:
--> 136         y = copier(x, memo)
    137     else:
    138         if issubclass(cls, type):

/usr/lib/python3.12/copy.py in _deepcopy_dict(x, memo, deepcopy)
    219     memo[id(x)] = y
    220     for key, value in x.items():
--> 221         y[deepcopy(key, memo)] = deepcopy(value, memo)
    222     return y
    223 d[dict] = _deepcopy_dict

/usr/lib/python3.12/copy.py in deepcopy(x, memo, _nil)
    134     copier = _deepcopy_dispatch.get(cls)
    135     if copier is not None:
--> 136         y = copier(x, memo)
    137     else:
    138         if issubclass(cls, type):

/usr/lib/python3.12/copy.py in _deepcopy_dict(x, memo, deepcopy)
    219     memo[id(x)] = y
    220     for key, value in x.items():
--> 221         y[deepcopy(key, memo)] = deepcopy(value, memo)
    222     return y
    223 d[dict] = _deepcopy_dict

/usr/lib/python3.12/copy.py in deepcopy(x, memo, _nil)
    160                     y = x
    161                 else:
--> 162                     y = _reconstruct(x, memo, *rv)
    163 
    164     # If is its own copy, don't memoize.

/usr/lib/python3.12/copy.py in _reconstruct(x, memo, func, args, state, listiter, dictiter, deepcopy)
    257     if state is not None:
    258         if deep:
--> 259             state = deepcopy(state, memo)
    260         if hasattr(y, '__setstate__'):
    261             y.__setstate__(state)

/usr/lib/python3.12/copy.py in deepcopy(x, memo, _nil)
    134     copier = _deepcopy_dispatch.get(cls)
    135     if copier is not None:
--> 136         y = copier(x, memo)
    137     else:
    138         if issubclass(cls, type):

/usr/lib/python3.12/copy.py in _deepcopy_dict(x, memo, deepcopy)
    219     memo[id(x)] = y
    220     for key, value in x.items():
--> 221         y[deepcopy(key, memo)] = deepcopy(value, memo)
    222     return y
    223 d[dict] = _deepcopy_dict

/usr/lib/python3.12/copy.py in deepcopy(x, memo, _nil)
    134     copier = _deepcopy_dispatch.get(cls)
    135     if copier is not None:
--> 136         y = copier(x, memo)
    137     else:
    138         if issubclass(cls, type):

/usr/lib/python3.12/copy.py in _deepcopy_dict(x, memo, deepcopy)
    219     memo[id(x)] = y
    220     for key, value in x.items():
--> 221         y[deepcopy(key, memo)] = deepcopy(value, memo)
    222     return y
    223 d[dict] = _deepcopy_dict

/usr/lib/python3.12/copy.py in deepcopy(x, memo, _nil)
    141             copier = getattr(x, "__deepcopy__", None)
    142             if copier is not None:
--> 143                 y = copier(memo)
    144             else:
    145                 reductor = dispatch_table.get(cls)

/usr/local/lib/python3.12/dist-packages/torch/nn/parameter.py in __deepcopy__(self, memo)
     77         else:
     78             result = type(self)(
---> 79                 self.data.clone(memory_format=torch.preserve_format), self.requires_grad
     80             )
     81             memo[id(self)] = result

OutOfMemoryError: CUDA out of memory. Tried to allocate 20.00 MiB. GPU 0 has a total capacity of 14.74 GiB of which 8.12 MiB is free. Process 4969 has 14.73 GiB memory in use. Of the allocated memory 14.28 GiB is allocated by PyTorch, and 331.81 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
In [8]:
import torch
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.optim.swa_utils import AveragedModel, SWALR, update_bn
from torch.cuda.amp import GradScaler, autocast
import os
import csv
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
import gc

# --- CONFIGURATION ---
SAVE_DIR = '/content/drive/MyDrive/SatMAE_Scratch_Results_12Frames/'
CHECKPOINT_PATH = os.path.join(SAVE_DIR, "checkpoint_latest.pth")
BEST_MODEL_PATH = os.path.join(SAVE_DIR, "best_model.pth")
LOG_PATH = os.path.join(SAVE_DIR, "training_log.csv")

# --- ULTRA-LOW MEMORY SETTINGS ---
PHYSICAL_BATCH_SIZE = 2   # <--- REDUCED to 2 to fix OOM
ACCUM_STEPS = 8           # <--- INCREASED to 8 (2 * 8 = 16 Effective Batch)
MAX_EPOCHS = 5000
PATIENCE_TRIGGER = 100
SWA_DURATION = 50

# Ensure Directory
if not os.path.exists(SAVE_DIR): os.makedirs(SAVE_DIR)

# --- 1. DATASET DEFINITION ---
class SatMAEDataset(Dataset):
    def __init__(self, x_path, y_path):
        self.x_data = np.load(x_path, mmap_mode='r')
        self.y_data = np.load(y_path, mmap_mode='r')
    def __len__(self): return len(self.x_data)
    def __getitem__(self, idx):
        x_np = self.x_data[idx]
        y_np = self.y_data[idx]
        x_float = x_np.astype(np.float32)
        y_float = y_np.astype(np.float32)
        x_float[x_np == 255] = 0.0
        y_float[y_np == 255] = 0.0
        x_float /= 250.0
        return torch.from_numpy(x_float), torch.from_numpy(y_float)

# --- 2. SETUP LOADERS ---
try:
    print("Initializing Datasets...")
    train_ds = SatMAEDataset(os.path.join(SAVE_DIR, 'train_x.npy'), os.path.join(SAVE_DIR, 'train_y.npy'))
    val_ds = SatMAEDataset(os.path.join(SAVE_DIR, 'val_x.npy'), os.path.join(SAVE_DIR, 'val_y.npy'))

    loaders = {
        'train': DataLoader(train_ds, batch_size=PHYSICAL_BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True),
        'val': DataLoader(val_ds, batch_size=PHYSICAL_BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)
    }
    print(" Loaders Ready.")
except Exception as e:
    print(f" Error loading data: {e}")
    raise e

# --- 3. MODEL SETUP ---
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

# Aggressive Memory Cleanup
torch.cuda.empty_cache()
gc.collect()

try:
    model = SatMAESegmentation(img_size=224, patch_size=16, in_chans=10, num_frames=12).to(device)
    criterion = CompoundLoss()
except NameError:
    raise NameError(" Error: 'SatMAESegmentation' not defined. Run Cell 2 first.")

optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=50, T_mult=2)
scaler = GradScaler()

# SWA Setup (Initialize on CPU first to save VRAM, move to GPU only when needed if possible,
# but for simplicity we keep on GPU. With batch=2 it should fit.)
swa_model = AveragedModel(model)
swa_scheduler = SWALR(optimizer, swa_lr=5e-5)

# --- 4. RESUME LOGIC ---
start_epoch = 0
best_iou = 0.0
patience_counter = 0
swa_active = False
swa_epoch_counter = 0
history = {'epoch': [], 'train_loss': [], 'val_loss': [], 'val_iou': []}

if os.path.exists(CHECKPOINT_PATH):
    print(" Found checkpoint. Resuming...")
    try:
        ckpt = torch.load(CHECKPOINT_PATH, map_location=device)
        model.load_state_dict(ckpt['model_state_dict'])
        optimizer.load_state_dict(ckpt['optimizer_state_dict'])
        scheduler.load_state_dict(ckpt['scheduler_state_dict'])
        start_epoch = ckpt['epoch'] + 1
        best_iou = ckpt['best_iou']
        patience_counter = ckpt['patience_counter']
        swa_active = ckpt['swa_active']
        swa_epoch_counter = ckpt['swa_epoch_counter']
        if swa_active:
            swa_model.load_state_dict(ckpt['swa_state_dict'])
            swa_scheduler.load_state_dict(ckpt['swa_scheduler_state_dict'])

        if os.path.exists(LOG_PATH):
            df = pd.read_csv(LOG_PATH)
            history['epoch'] = df['epoch'].tolist()
            history['train_loss'] = df['train_loss'].tolist()
            history['val_loss'] = df['val_loss'].tolist()
            history['val_iou'] = df['val_iou'].tolist()
        print(f"Resumed at Epoch {start_epoch}. Best IoU: {best_iou:.4f}")
    except Exception as e:
        print(f" Warning: Checkpoint corrupt or mismatch. Starting fresh. ({e})")

# --- 5. TRAINING LOOP ---
print(f" Starting Training. Max: {MAX_EPOCHS} Eps. Batch: {PHYSICAL_BATCH_SIZE} (Accum to 16)")

try:
    for ep in range(start_epoch, MAX_EPOCHS):
        model.train()
        train_loss = 0
        optimizer.zero_grad(set_to_none=True) # Saves Memory

        # Training Step
        pbar = tqdm(loaders['train'], desc=f"Ep {ep+1}", leave=False)
        for i, (x, y) in enumerate(pbar):
            x, y = x.to(device), y.to(device)

            # Mixed Precision Forward
            with autocast(enabled=True):
                preds = model(x)
                loss = criterion(preds, y)
                loss = loss / ACCUM_STEPS

            # Backward
            scaler.scale(loss).backward()

            # Update weights every ACCUM_STEPS
            if (i + 1) % ACCUM_STEPS == 0:
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad(set_to_none=True)

            train_loss += loss.item() * ACCUM_STEPS

        # Validation Step
        eval_model = swa_model if swa_active else model
        eval_model.eval()
        val_loss = 0
        val_iou_sum = 0

        # Explicit Memory Cleanup before Validation
        torch.cuda.empty_cache()

        with torch.no_grad():
            for x, y in loaders['val']:
                x, y = x.to(device), y.to(device)

                with autocast(enabled=True):
                    if swa_active: preds = eval_model(x)
                    else: preds = model(x)
                    val_loss += criterion(preds, y).item()

                probs = torch.sigmoid(preds)
                pred_mask = (probs > 0.5).float()
                inter = (pred_mask * y).sum()
                union = pred_mask.sum() + y.sum() - inter
                iou = (inter + 1e-6) / (union + 1e-6)
                val_iou_sum += iou.item()

        # Stats
        avg_t = train_loss / len(loaders['train'])
        avg_v = val_loss / len(loaders['val'])
        avg_iou = val_iou_sum / len(loaders['val'])

        # Append to CSV
        with open(LOG_PATH, 'a', newline='') as f:
            writer = csv.writer(f)
            if ep == 0 and not os.path.exists(LOG_PATH):
                writer.writerow(['epoch', 'train_loss', 'val_loss', 'val_iou'])
            writer.writerow([ep+1, avg_t, avg_v, avg_iou])

        # Status & Logic
        status = ""
        if swa_active:
            swa_model.update_parameters(model)
            swa_scheduler.step()
            swa_epoch_counter += 1
            status = f"SWA Mode ({swa_epoch_counter}/{SWA_DURATION})"
            if swa_epoch_counter >= SWA_DURATION:
                print(f"Ep {ep+1} | T: {avg_t:.4f} | V: {avg_v:.4f} | IoU: {avg_iou:.4f} | {status}")
                print("SWA Complete. Saving Final Model.")
                update_bn(loaders['train'], swa_model, device=device)
                torch.save(swa_model.module.state_dict(), os.path.join(SAVE_DIR, "final_swa_model.pth"))
                break
        else:
            scheduler.step()
            if avg_iou > best_iou:
                best_iou = avg_iou
                patience_counter = 0
                torch.save(model.state_dict(), BEST_MODEL_PATH)
                status = f" Best IoU!"
            else:
                patience_counter += 1
                status = f"No Improv ({patience_counter}/{PATIENCE_TRIGGER})"
            if patience_counter >= PATIENCE_TRIGGER:
                print(f" Patience Limit Reached. Triggering SWA for {SWA_DURATION} epochs...")
                swa_active = True
                swa_epoch_counter = 0
                model.load_state_dict(torch.load(BEST_MODEL_PATH))

        print(f"Ep {ep+1} | T: {avg_t:.4f} | V: {avg_v:.4f} | IoU: {avg_iou:.4f} | {status}")

        # Save Checkpoint
        torch.save({
            'epoch': ep, 'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'swa_state_dict': swa_model.state_dict() if swa_active else None,
            'swa_scheduler_state_dict': swa_scheduler.state_dict() if swa_active else None,
            'best_iou': best_iou, 'patience_counter': patience_counter,
            'swa_active': swa_active, 'swa_epoch_counter': swa_epoch_counter
        }, CHECKPOINT_PATH)

except KeyboardInterrupt:
    print("Training Interrupted. Checkpoint Saved.")

# Plot Results
if os.path.exists(LOG_PATH):
    df = pd.read_csv(LOG_PATH)
    if len(df) > 0:
        plt.figure(figsize=(10, 5))
        plt.plot(df['train_loss'], label='Train Loss')
        plt.plot(df['val_loss'], label='Val Loss')
        plt.legend()
        plt.title("Training Curves")
        plt.savefig(os.path.join(SAVE_DIR, "loss_curve.png"))
        plt.show()
Initializing Datasets...
 Loaders Ready.
Device: cuda
 Starting Training. Max: 5000 Eps. Batch: 2 (Accum to 16)
/tmp/ipython-input-4233285996.py:77: FutureWarning: `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.
  scaler = GradScaler()
Ep 1:   0%|          | 0/5 [00:00<?, ?it/s]
/tmp/ipython-input-4233285996.py:133: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
  with autocast(enabled=True):
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
/tmp/ipython-input-4233285996.py in <cell line: 0>()
    137 
    138             # Backward
--> 139             scaler.scale(loss).backward()
    140 
    141             # Update weights every ACCUM_STEPS

/usr/local/lib/python3.12/dist-packages/torch/_tensor.py in backward(self, gradient, retain_graph, create_graph, inputs)
    623                 inputs=inputs,
    624             )
--> 625         torch.autograd.backward(
    626             self, gradient, retain_graph, create_graph, inputs=inputs
    627         )

/usr/local/lib/python3.12/dist-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    352     # some Python versions print out the first line of a multi-line function
    353     # calls in the traceback and some print out the last line
--> 354     _engine_run_backward(
    355         tensors,
    356         grad_tensors_,

/usr/local/lib/python3.12/dist-packages/torch/autograd/graph.py in _engine_run_backward(t_outputs, *args, **kwargs)
    839         unregister_hooks = _register_logging_hooks_on_whole_graph(t_outputs)
    840     try:
--> 841         return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    842             t_outputs, *args, **kwargs
    843         )  # Calls into the C++ engine to run the backward pass

RuntimeError: CUDA error: CUBLAS_STATUS_ALLOC_FAILED when calling `cublasCreate(handle)`
In [10]:
import torch
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.optim.swa_utils import AveragedModel, SWALR, update_bn
from torch.cuda.amp import GradScaler, autocast
import os
import csv
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
import gc

# --- CONFIGURATION ---
SAVE_DIR = '/content/drive/MyDrive/SatMAE_Scratch_Results_12Frames/'
CHECKPOINT_PATH = os.path.join(SAVE_DIR, "checkpoint_latest.pth")
BEST_MODEL_PATH = os.path.join(SAVE_DIR, "best_model.pth")
LOG_PATH = os.path.join(SAVE_DIR, "training_log.csv")

# --- ULTRA-LOW MEMORY SETTINGS ---
PHYSICAL_BATCH_SIZE = 1   # Process 1 image at a time
ACCUM_STEPS = 16          # Update weights every 16 images (Effective Batch = 16)
MAX_EPOCHS = 5000
PATIENCE_TRIGGER = 100
SWA_DURATION = 50

# Ensure Directory
if not os.path.exists(SAVE_DIR): os.makedirs(SAVE_DIR)

# --- 1. HELPERS & DATASET ---
def compute_metrics(pred_probs, targets, threshold=0.5):
    """Calculates IoU and Dice Score"""
    pred_mask = (pred_probs > threshold).float()
    intersection = (pred_mask * targets).sum()
    union = pred_mask.sum() + targets.sum() - intersection
    iou = (intersection + 1e-6) / (union + 1e-6)
    dice = (2 * intersection + 1e-6) / (pred_mask.sum() + targets.sum() + 1e-6)
    return iou.item(), dice.item()

class SatMAEDataset(Dataset):
    def __init__(self, x_path, y_path):
        self.x_data = np.load(x_path, mmap_mode='r')
        self.y_data = np.load(y_path, mmap_mode='r')
    def __len__(self): return len(self.x_data)
    def __getitem__(self, idx):
        x_np = self.x_data[idx]
        y_np = self.y_data[idx]
        x_float = x_np.astype(np.float32)
        y_float = y_np.astype(np.float32)
        x_float[x_np == 255] = 0.0 # Handle NoData
        y_float[y_np == 255] = 0.0
        x_float /= 250.0           # Scale 0-250 -> 0-1
        return torch.from_numpy(x_float), torch.from_numpy(y_float)

# --- 2. SETUP LOADERS ---
try:
    print("Initializing Datasets...")
    train_ds = SatMAEDataset(os.path.join(SAVE_DIR, 'train_x.npy'), os.path.join(SAVE_DIR, 'train_y.npy'))
    val_ds = SatMAEDataset(os.path.join(SAVE_DIR, 'val_x.npy'), os.path.join(SAVE_DIR, 'val_y.npy'))
    loaders = {
        'train': DataLoader(train_ds, batch_size=PHYSICAL_BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True),
        'val': DataLoader(val_ds, batch_size=PHYSICAL_BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)
    }
except Exception as e:
    raise RuntimeError(f" Error loading data. Did you run Cell 1? ({e})")

# --- 3. MODEL SETUP ---
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

# Aggressive Cleanup
torch.cuda.empty_cache()
gc.collect()

try:
    # Ensure Cell 2 was run
    model = SatMAESegmentation(img_size=224, patch_size=16, in_chans=10, num_frames=12).to(device)
    criterion = CompoundLoss()
except NameError:
    raise NameError(" Error: Model classes not found. Please RUN CELL 2 (Architecture) first.")

optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=50, T_mult=2)
scaler = GradScaler() # Mixed Precision Scaler

# SWA Setup
swa_model = AveragedModel(model)
swa_scheduler = SWALR(optimizer, swa_lr=5e-5)

# --- 4. RESUME LOGIC ---
start_epoch = 0; best_iou = 0.0; patience_counter = 0; swa_active = False; swa_epoch_counter = 0
history = {'epoch': [], 'train_loss': [], 'val_loss': [], 'val_iou': []}

if os.path.exists(CHECKPOINT_PATH):
    print(" Resuming from Checkpoint...")
    try:
        ckpt = torch.load(CHECKPOINT_PATH, map_location=device)
        model.load_state_dict(ckpt['model_state_dict'])
        optimizer.load_state_dict(ckpt['optimizer_state_dict'])
        scheduler.load_state_dict(ckpt['scheduler_state_dict'])
        start_epoch = ckpt['epoch'] + 1
        best_iou = ckpt['best_iou']
        patience_counter = ckpt['patience_counter']
        swa_active = ckpt['swa_active']
        swa_epoch_counter = ckpt['swa_epoch_counter']

        if swa_active:
            swa_model.load_state_dict(ckpt['swa_state_dict'])
            swa_scheduler.load_state_dict(ckpt['swa_scheduler_state_dict'])

        if os.path.exists(LOG_PATH):
            df = pd.read_csv(LOG_PATH)
            history['epoch'] = df['epoch'].tolist()
            history['train_loss'] = df['train_loss'].tolist()
            history['val_loss'] = df['val_loss'].tolist()
            history['val_iou'] = df['val_iou'].tolist()
        print(f"Resumed at Epoch {start_epoch}. Best IoU: {best_iou:.4f}")
    except Exception as e:
        print(f" Checkpoint load failed ({e}). Starting fresh.")

# --- 5. TRAINING LOOP (The Engine) ---
print(f" Training Start. Batch: {PHYSICAL_BATCH_SIZE} | Accum: {ACCUM_STEPS}")
print("-" * 50)

try:
    for ep in range(start_epoch, MAX_EPOCHS):
        model.train()
        train_loss = 0
        optimizer.zero_grad(set_to_none=True)

        pbar = tqdm(loaders['train'], desc=f"Ep {ep+1}", leave=False)
        for i, (x, y) in enumerate(pbar):
            x, y = x.to(device), y.to(device)

            # Forward (Mixed Precision)
            with autocast(enabled=True):
                preds = model(x)
                loss = criterion(preds, y)
                loss = loss / ACCUM_STEPS # Normalize loss

            # Backward
            scaler.scale(loss).backward()

            # Step (Gradient Accumulation)
            if (i + 1) % ACCUM_STEPS == 0:
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad(set_to_none=True)

            train_loss += loss.item() * ACCUM_STEPS

        # Cleanup before Validation
        del x, y, preds, loss
        torch.cuda.empty_cache()

        # Validation
        eval_model = swa_model if swa_active else model
        eval_model.eval()
        val_loss = 0
        val_iou_sum = 0

        with torch.no_grad():
            for x, y in loaders['val']:
                x, y = x.to(device), y.to(device)

                with autocast(enabled=True):
                    if swa_active: preds = eval_model(x)
                    else: preds = model(x)
                    val_loss += criterion(preds, y).item()

                probs = torch.sigmoid(preds)
                iou, _ = compute_metrics(probs, y)
                val_iou_sum += iou

        # Stats
        avg_t = train_loss / len(loaders['train'])
        avg_v = val_loss / len(loaders['val'])
        avg_iou = val_iou_sum / len(loaders['val'])

        # Logging
        with open(LOG_PATH, 'a', newline='') as f:
            writer = csv.writer(f)
            if ep == 0 and not os.path.exists(LOG_PATH):
                writer.writerow(['epoch', 'train_loss', 'val_loss', 'val_iou'])
            writer.writerow([ep+1, avg_t, avg_v, avg_iou])

        # SWA & Patience Logic [Visual Flowchart Reference below]
        #
        status = ""
        if swa_active:
            swa_model.update_parameters(model)
            swa_scheduler.step()
            swa_epoch_counter += 1
            status = f"SWA Mode ({swa_epoch_counter}/{SWA_DURATION})"

            if swa_epoch_counter >= SWA_DURATION:
                print(f"Ep {ep+1} | T: {avg_t:.4f} | V: {avg_v:.4f} | IoU: {avg_iou:.4f} | {status}")
                print(" SWA Complete. Saving Final Model.")
                update_bn(loaders['train'], swa_model, device=device)
                torch.save(swa_model.module.state_dict(), os.path.join(SAVE_DIR, "final_swa_model.pth"))
                break
        else:
            scheduler.step()
            if avg_iou > best_iou:
                best_iou = avg_iou
                patience_counter = 0
                torch.save(model.state_dict(), BEST_MODEL_PATH)
                status = f" Best IoU!"
            else:
                patience_counter += 1
                status = f"Wait ({patience_counter})"

            if patience_counter >= PATIENCE_TRIGGER:
                print(f" Patience Limit Reached. Triggering SWA for {SWA_DURATION} epochs...")
                swa_active = True
                swa_epoch_counter = 0
                # Reset to best model before starting SWA
                model.load_state_dict(torch.load(BEST_MODEL_PATH))

        print(f"Ep {ep+1} | T: {avg_t:.4f} | V: {avg_v:.4f} | IoU: {avg_iou:.4f} | {status}")

        # Rolling Checkpoint
        torch.save({
            'epoch': ep, 'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'swa_state_dict': swa_model.state_dict() if swa_active else None,
            'swa_scheduler_state_dict': swa_scheduler.state_dict() if swa_active else None,
            'best_iou': best_iou, 'patience_counter': patience_counter,
            'swa_active': swa_active, 'swa_epoch_counter': swa_epoch_counter
        }, CHECKPOINT_PATH)

except KeyboardInterrupt:
    print(" Training Interrupted. Checkpoint Saved.")

# Plot Results
if os.path.exists(LOG_PATH):
    df = pd.read_csv(LOG_PATH)
    if len(df) > 0:
        plt.figure(figsize=(10, 5))
        plt.plot(df['train_loss'], label='Train Loss')
        plt.plot(df['val_loss'], label='Val Loss')
        plt.legend()
        plt.title("Training Curves")
        plt.savefig(os.path.join(SAVE_DIR, "loss_curve.png"))
        plt.show()
Initializing Datasets...
Device: cuda
 Training Start. Batch: 1 | Accum: 16
--------------------------------------------------
/tmp/ipython-input-3170614832.py:84: FutureWarning: `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.
  scaler = GradScaler() # Mixed Precision Scaler
Ep 1:   0%|          | 0/9 [00:00<?, ?it/s]
/tmp/ipython-input-3170614832.py:136: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
  with autocast(enabled=True):
/tmp/ipython-input-3170614832.py:166: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
  with autocast(enabled=True):
Ep 1 | T: 1.4361 | V: 0.9212 | IoU: 0.0000 | 🏆 Best IoU!
Ep 2:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 2 | T: 1.4359 | V: 0.9345 | IoU: 0.0245 | 🏆 Best IoU!
Ep 3:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 3 | T: 1.4360 | V: 0.9901 | IoU: 0.3243 | 🏆 Best IoU!
Ep 4:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 4 | T: 1.4360 | V: 1.1157 | IoU: 0.3136 | Wait (1)
Ep 5:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 5 | T: 1.4363 | V: 1.2833 | IoU: 0.2935 | Wait (2)
Ep 6:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 6 | T: 1.4359 | V: 1.4157 | IoU: 0.2826 | Wait (3)
Ep 7:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 7 | T: 1.4355 | V: 1.4877 | IoU: 0.2782 | Wait (4)
Ep 8:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 8 | T: 1.4359 | V: 1.5201 | IoU: 0.2767 | Wait (5)
Ep 9:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 9 | T: 1.4362 | V: 1.5329 | IoU: 0.2761 | Wait (6)
Ep 10:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 10 | T: 1.4365 | V: 1.5380 | IoU: 0.2762 | Wait (7)
Ep 11:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 11 | T: 1.4367 | V: 1.5406 | IoU: 0.2762 | Wait (8)
Ep 12:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 12 | T: 1.4358 | V: 1.5412 | IoU: 0.2761 | Wait (9)
Ep 13:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 13 | T: 1.4361 | V: 1.5414 | IoU: 0.2760 | Wait (10)
Ep 14:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 14 | T: 1.4366 | V: 1.5416 | IoU: 0.2760 | Wait (11)
Ep 15:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 15 | T: 1.4357 | V: 1.5416 | IoU: 0.2759 | Wait (12)
Ep 16:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 16 | T: 1.4358 | V: 1.5414 | IoU: 0.2760 | Wait (13)
Ep 17:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 17 | T: 1.4363 | V: 1.5413 | IoU: 0.2760 | Wait (14)
Ep 18:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 18 | T: 1.4356 | V: 1.5418 | IoU: 0.2761 | Wait (15)
Ep 19:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 19 | T: 1.4362 | V: 1.5419 | IoU: 0.2759 | Wait (16)
Ep 20:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 20 | T: 1.4363 | V: 1.5418 | IoU: 0.2759 | Wait (17)
Ep 21:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 21 | T: 1.4360 | V: 1.5415 | IoU: 0.2760 | Wait (18)
Ep 22:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 22 | T: 1.4365 | V: 1.5417 | IoU: 0.2761 | Wait (19)
Ep 23:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 23 | T: 1.4361 | V: 1.5414 | IoU: 0.2758 | Wait (20)
Ep 24:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 24 | T: 1.4361 | V: 1.5415 | IoU: 0.2760 | Wait (21)
Ep 25:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 25 | T: 1.4361 | V: 1.5415 | IoU: 0.2759 | Wait (22)
Ep 26:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 26 | T: 1.4359 | V: 1.5415 | IoU: 0.2761 | Wait (23)
Ep 27:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 27 | T: 1.4351 | V: 1.5418 | IoU: 0.2761 | Wait (24)
Ep 28:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 28 | T: 1.4359 | V: 1.5415 | IoU: 0.2759 | Wait (25)
Ep 29:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 29 | T: 1.4355 | V: 1.5417 | IoU: 0.2761 | Wait (26)
Ep 30:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 30 | T: 1.4361 | V: 1.5414 | IoU: 0.2759 | Wait (27)
Ep 31:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 31 | T: 1.4359 | V: 1.5413 | IoU: 0.2762 | Wait (28)
 Training Interrupted. Checkpoint Saved.
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
/usr/local/lib/python3.12/dist-packages/pandas/core/indexes/base.py in get_loc(self, key)
   3804         try:
-> 3805             return self._engine.get_loc(casted_key)
   3806         except KeyError as err:

index.pyx in pandas._libs.index.IndexEngine.get_loc()

index.pyx in pandas._libs.index.IndexEngine.get_loc()

pandas/_libs/hashtable_class_helper.pxi in pandas._libs.hashtable.PyObjectHashTable.get_item()

pandas/_libs/hashtable_class_helper.pxi in pandas._libs.hashtable.PyObjectHashTable.get_item()

KeyError: 'train_loss'

The above exception was the direct cause of the following exception:

KeyError                                  Traceback (most recent call last)
/tmp/ipython-input-3170614832.py in <cell line: 0>()
    239     if len(df) > 0:
    240         plt.figure(figsize=(10, 5))
--> 241         plt.plot(df['train_loss'], label='Train Loss')
    242         plt.plot(df['val_loss'], label='Val Loss')
    243         plt.legend()

/usr/local/lib/python3.12/dist-packages/pandas/core/frame.py in __getitem__(self, key)
   4100             if self.columns.nlevels > 1:
   4101                 return self._getitem_multilevel(key)
-> 4102             indexer = self.columns.get_loc(key)
   4103             if is_integer(indexer):
   4104                 indexer = [indexer]

/usr/local/lib/python3.12/dist-packages/pandas/core/indexes/base.py in get_loc(self, key)
   3810             ):
   3811                 raise InvalidIndexError(key)
-> 3812             raise KeyError(key) from err
   3813         except TypeError:
   3814             # If we have a listlike key, _check_indexing_error will raise

KeyError: 'train_loss'
<Figure size 1000x500 with 0 Axes>
In [11]:
import torch
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.optim.swa_utils import AveragedModel, SWALR, update_bn
from torch.cuda.amp import GradScaler, autocast
import os
import csv
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
import gc

# --- CONFIGURATION ---
SAVE_DIR = '/content/drive/MyDrive/SatMAE_Scratch_Results_12Frames/'
CHECKPOINT_PATH = os.path.join(SAVE_DIR, "checkpoint_latest.pth")
BEST_MODEL_PATH = os.path.join(SAVE_DIR, "best_model.pth")
LOG_PATH = os.path.join(SAVE_DIR, "training_log.csv")

# --- ECO-MODE SETTINGS (CRASH PREVENTION) ---
PHYSICAL_BATCH_SIZE = 1
ACCUM_STEPS = 16
MAX_EPOCHS = 5000
PATIENCE_TRIGGER = 100
SWA_DURATION = 50
NUM_WORKERS = 0           # <--- CHANGED: 0 saves massive RAM (no subprocesses)
PIN_MEMORY = False        # <--- CHANGED: False saves RAM (no buffer reservation)

# Ensure Directory
if not os.path.exists(SAVE_DIR): os.makedirs(SAVE_DIR)

# --- 1. HELPERS & DATASET ---
def compute_metrics(pred_probs, targets, threshold=0.5):
    pred_mask = (pred_probs > threshold).float()
    intersection = (pred_mask * targets).sum()
    union = pred_mask.sum() + targets.sum() - intersection
    iou = (intersection + 1e-6) / (union + 1e-6)
    dice = (2 * intersection + 1e-6) / (pred_mask.sum() + targets.sum() + 1e-6)
    return iou.item(), dice.item()

class SatMAEDataset(Dataset):
    def __init__(self, x_path, y_path):
        self.x_data = np.load(x_path, mmap_mode='r')
        self.y_data = np.load(y_path, mmap_mode='r')
    def __len__(self): return len(self.x_data)
    def __getitem__(self, idx):
        x_np = self.x_data[idx]
        y_np = self.y_data[idx]
        x_float = x_np.astype(np.float32)
        y_float = y_np.astype(np.float32)
        x_float[x_np == 255] = 0.0
        y_float[y_np == 255] = 0.0
        x_float /= 250.0
        return torch.from_numpy(x_float), torch.from_numpy(y_float)

# --- 2. SETUP LOADERS (OPTIMIZED) ---
try:
    print("Initializing Datasets in ECO-MODE...")
    train_ds = SatMAEDataset(os.path.join(SAVE_DIR, 'train_x.npy'), os.path.join(SAVE_DIR, 'train_y.npy'))
    val_ds = SatMAEDataset(os.path.join(SAVE_DIR, 'val_x.npy'), os.path.join(SAVE_DIR, 'val_y.npy'))
    loaders = {
        'train': DataLoader(train_ds, batch_size=PHYSICAL_BATCH_SIZE, shuffle=True,
                          num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY),
        'val': DataLoader(val_ds, batch_size=PHYSICAL_BATCH_SIZE, shuffle=False,
                        num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)
    }
except Exception as e:
    raise RuntimeError(f" Error loading data: {e}")

# --- 3. MODEL SETUP ---
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

# Aggressive Cleanup
torch.cuda.empty_cache()
gc.collect()

try:
    model = SatMAESegmentation(img_size=224, patch_size=16, in_chans=10, num_frames=12).to(device)
    criterion = CompoundLoss()
except NameError:
    raise NameError(" Error: Model classes not found. Please RUN CELL 2 first.")

optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=50, T_mult=2)
scaler = GradScaler()
swa_model = AveragedModel(model)
swa_scheduler = SWALR(optimizer, swa_lr=5e-5)

# --- 4. RESUME LOGIC ---
start_epoch = 0; best_iou = 0.0; patience_counter = 0; swa_active = False; swa_epoch_counter = 0
history = {'epoch': [], 'train_loss': [], 'val_loss': [], 'val_iou': []}

if os.path.exists(CHECKPOINT_PATH):
    print(" Resuming from Checkpoint...")
    try:
        ckpt = torch.load(CHECKPOINT_PATH, map_location=device)
        model.load_state_dict(ckpt['model_state_dict'])
        optimizer.load_state_dict(ckpt['optimizer_state_dict'])
        scheduler.load_state_dict(ckpt['scheduler_state_dict'])
        start_epoch = ckpt['epoch'] + 1
        best_iou = ckpt['best_iou']
        patience_counter = ckpt['patience_counter']
        swa_active = ckpt['swa_active']
        swa_epoch_counter = ckpt['swa_epoch_counter']
        if swa_active:
            swa_model.load_state_dict(ckpt['swa_state_dict'])
            swa_scheduler.load_state_dict(ckpt['swa_scheduler_state_dict'])
        if os.path.exists(LOG_PATH):
            df = pd.read_csv(LOG_PATH)
            # Reload history safely
            history['epoch'] = df['epoch'].tolist()
            history['train_loss'] = df['train_loss'].tolist()
            history['val_loss'] = df['val_loss'].tolist()
            history['val_iou'] = df['val_iou'].tolist()
        print(f"Resumed at Epoch {start_epoch}. Best IoU: {best_iou:.4f}")
    except Exception as e:
        print(f"Checkpoint load failed ({e}). Starting fresh.")

# --- 5. TRAINING LOOP ---
print(f" Training (Eco-Mode). Workers: {NUM_WORKERS} | Pin Mem: {PIN_MEMORY}")

try:
    for ep in range(start_epoch, MAX_EPOCHS):
        model.train()
        train_loss = 0
        optimizer.zero_grad(set_to_none=True)

        pbar = tqdm(loaders['train'], desc=f"Ep {ep+1}", leave=False)
        for i, (x, y) in enumerate(pbar):
            x, y = x.to(device), y.to(device)

            with autocast(enabled=True):
                preds = model(x)
                loss = criterion(preds, y)
                loss = loss / ACCUM_STEPS

            scaler.scale(loss).backward()

            if (i + 1) % ACCUM_STEPS == 0:
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad(set_to_none=True)

            train_loss += loss.item() * ACCUM_STEPS

        del x, y, preds, loss
        torch.cuda.empty_cache()

        eval_model = swa_model if swa_active else model
        eval_model.eval()
        val_loss, val_iou_sum = 0, 0

        with torch.no_grad():
            for x, y in loaders['val']:
                x, y = x.to(device), y.to(device)
                with autocast(enabled=True):
                    if swa_active: preds = eval_model(x)
                    else: preds = model(x)
                    val_loss += criterion(preds, y).item()

                probs = torch.sigmoid(preds)
                iou, _ = compute_metrics(probs, y)
                val_iou_sum += iou

        avg_t = train_loss / len(loaders['train'])
        avg_v = val_loss / len(loaders['val'])
        avg_iou = val_iou_sum / len(loaders['val'])

        with open(LOG_PATH, 'a', newline='') as f:
            writer = csv.writer(f)
            if ep == 0 and not os.path.exists(LOG_PATH):
                writer.writerow(['epoch', 'train_loss', 'val_loss', 'val_iou'])
            writer.writerow([ep+1, avg_t, avg_v, avg_iou])

        status = ""
        if swa_active:
            swa_model.update_parameters(model)
            swa_scheduler.step()
            swa_epoch_counter += 1
            status = f"SWA ({swa_epoch_counter}/{SWA_DURATION})"
            if swa_epoch_counter >= SWA_DURATION:
                print(f"Ep {ep+1} | IoU: {avg_iou:.4f} |  SWA DONE")
                update_bn(loaders['train'], swa_model, device=device)
                torch.save(swa_model.module.state_dict(), os.path.join(SAVE_DIR, "final_swa_model.pth"))
                break
        else:
            scheduler.step()
            if avg_iou > best_iou:
                best_iou = avg_iou
                patience_counter = 0
                torch.save(model.state_dict(), BEST_MODEL_PATH)
                status = f"Best!"
            else:
                patience_counter += 1
                status = f"Wait ({patience_counter})"
            if patience_counter >= PATIENCE_TRIGGER:
                print(f" Triggering SWA...")
                swa_active = True
                swa_epoch_counter = 0
                model.load_state_dict(torch.load(BEST_MODEL_PATH))

        print(f"Ep {ep+1} | T: {avg_t:.4f} | V: {avg_v:.4f} | IoU: {avg_iou:.4f} | {status}")

        torch.save({
            'epoch': ep, 'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'swa_state_dict': swa_model.state_dict() if swa_active else None,
            'swa_scheduler_state_dict': swa_scheduler.state_dict() if swa_active else None,
            'best_iou': best_iou, 'patience_counter': patience_counter,
            'swa_active': swa_active, 'swa_epoch_counter': swa_epoch_counter
        }, CHECKPOINT_PATH)

        # Force garbage collection every epoch to keep RAM low
        gc.collect()

except KeyboardInterrupt:
    print(" Training Interrupted. Checkpoint Saved.")

if os.path.exists(LOG_PATH):
    df = pd.read_csv(LOG_PATH)
    if len(df) > 0:
        plt.figure(figsize=(10, 5))
        plt.plot(df['train_loss'], label='Train Loss')
        plt.plot(df['val_loss'], label='Val Loss')
        plt.legend()
        plt.title("Training Curves")
        plt.savefig(os.path.join(SAVE_DIR, "loss_curve.png"))
        plt.show()
Initializing Datasets in ECO-MODE...
Device: cuda
/tmp/ipython-input-1738852815.py:86: FutureWarning: `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.
  scaler = GradScaler()
 Resuming from Checkpoint...
Checkpoint load failed (PytorchStreamReader failed locating file data/101: file not found). Starting fresh.
 Training (Eco-Mode). Workers: 0 | Pin Mem: False
Ep 1:   0%|          | 0/9 [00:00<?, ?it/s]
/tmp/ipython-input-1738852815.py:133: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
  with autocast(enabled=True):
/tmp/ipython-input-1738852815.py:157: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
  with autocast(enabled=True):
Ep 1 | T: 1.4697 | V: 0.9073 | IoU: 0.6619 | Best!
Ep 2:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 2 | T: 1.4705 | V: 0.9221 | IoU: 0.6619 | Wait (1)
Ep 3:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 3 | T: 1.4699 | V: 0.9855 | IoU: 0.6610 | Wait (2)
Ep 4:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 4 | T: 1.4708 | V: 1.1187 | IoU: 0.5708 | Wait (3)
Ep 5:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 5 | T: 1.4704 | V: 1.2968 | IoU: 0.4553 | Wait (4)
Ep 6:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 6 | T: 1.4698 | V: 1.4386 | IoU: 0.4018 | Wait (5)
Ep 7:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 7 | T: 1.4707 | V: 1.5153 | IoU: 0.3800 | Wait (6)
Ep 8:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 8 | T: 1.4707 | V: 1.5490 | IoU: 0.3717 | Wait (7)
Ep 9:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 9 | T: 1.4699 | V: 1.5634 | IoU: 0.3686 | Wait (8)
Ep 10:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 10 | T: 1.4711 | V: 1.5688 | IoU: 0.3673 | Wait (9)
Ep 11:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 11 | T: 1.4708 | V: 1.5707 | IoU: 0.3668 | Wait (10)
Ep 12:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 12 | T: 1.4709 | V: 1.5720 | IoU: 0.3667 | Wait (11)
Ep 13:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 13 | T: 1.4703 | V: 1.5714 | IoU: 0.3667 | Wait (12)
Ep 14:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 14 | T: 1.4704 | V: 1.5715 | IoU: 0.3668 | Wait (13)
Ep 15:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 15 | T: 1.4700 | V: 1.5719 | IoU: 0.3667 | Wait (14)
Ep 16:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 16 | T: 1.4701 | V: 1.5719 | IoU: 0.3668 | Wait (15)
Ep 17:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 17 | T: 1.4699 | V: 1.5720 | IoU: 0.3668 | Wait (16)
Ep 18:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 18 | T: 1.4705 | V: 1.5723 | IoU: 0.3667 | Wait (17)
Ep 19:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 19 | T: 1.4704 | V: 1.5726 | IoU: 0.3664 | Wait (18)
Ep 20:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 20 | T: 1.4700 | V: 1.5727 | IoU: 0.3667 | Wait (19)
Ep 21:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 21 | T: 1.4705 | V: 1.5725 | IoU: 0.3666 | Wait (20)
Ep 22:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 22 | T: 1.4707 | V: 1.5724 | IoU: 0.3665 | Wait (21)
Ep 23:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 23 | T: 1.4704 | V: 1.5722 | IoU: 0.3665 | Wait (22)
Ep 24:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 24 | T: 1.4706 | V: 1.5724 | IoU: 0.3665 | Wait (23)
Ep 25:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 25 | T: 1.4701 | V: 1.5727 | IoU: 0.3663 | Wait (24)
Ep 26:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 26 | T: 1.4708 | V: 1.5726 | IoU: 0.3666 | Wait (25)
Ep 27:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 27 | T: 1.4699 | V: 1.5723 | IoU: 0.3668 | Wait (26)
Ep 28:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 28 | T: 1.4702 | V: 1.5723 | IoU: 0.3666 | Wait (27)
Ep 29:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 29 | T: 1.4702 | V: 1.5718 | IoU: 0.3667 | Wait (28)
Ep 30:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 30 | T: 1.4707 | V: 1.5723 | IoU: 0.3665 | Wait (29)
Ep 31:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 31 | T: 1.4701 | V: 1.5722 | IoU: 0.3665 | Wait (30)
Ep 32:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 32 | T: 1.4705 | V: 1.5721 | IoU: 0.3666 | Wait (31)
Ep 33:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 33 | T: 1.4706 | V: 1.5720 | IoU: 0.3667 | Wait (32)
Ep 34:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 34 | T: 1.4712 | V: 1.5723 | IoU: 0.3664 | Wait (33)
Ep 35:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 35 | T: 1.4703 | V: 1.5722 | IoU: 0.3666 | Wait (34)
Ep 36:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 36 | T: 1.4709 | V: 1.5722 | IoU: 0.3664 | Wait (35)
Ep 37:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 37 | T: 1.4703 | V: 1.5723 | IoU: 0.3666 | Wait (36)
Ep 38:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 38 | T: 1.4707 | V: 1.5727 | IoU: 0.3662 | Wait (37)
Ep 39:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 39 | T: 1.4701 | V: 1.5724 | IoU: 0.3667 | Wait (38)
Ep 40:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 40 | T: 1.4709 | V: 1.5723 | IoU: 0.3664 | Wait (39)
Ep 41:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 41 | T: 1.4707 | V: 1.5722 | IoU: 0.3663 | Wait (40)
Ep 42:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 42 | T: 1.4706 | V: 1.5723 | IoU: 0.3660 | Wait (41)
Ep 43:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 43 | T: 1.4708 | V: 1.5723 | IoU: 0.3665 | Wait (42)
Ep 44:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 44 | T: 1.4707 | V: 1.5721 | IoU: 0.3669 | Wait (43)
Ep 45:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 45 | T: 1.4702 | V: 1.5723 | IoU: 0.3668 | Wait (44)
Ep 46:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 46 | T: 1.4699 | V: 1.5722 | IoU: 0.3669 | Wait (45)
Ep 47:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 47 | T: 1.4697 | V: 1.5723 | IoU: 0.3666 | Wait (46)
Ep 48:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 48 | T: 1.4707 | V: 1.5726 | IoU: 0.3663 | Wait (47)
Ep 49:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 49 | T: 1.4708 | V: 1.5724 | IoU: 0.3664 | Wait (48)
Ep 50:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 50 | T: 1.4703 | V: 1.5720 | IoU: 0.3665 | Wait (49)
Ep 51:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 51 | T: 1.4704 | V: 1.5722 | IoU: 0.3666 | Wait (50)
Ep 52:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 52 | T: 1.4709 | V: 1.5723 | IoU: 0.3664 | Wait (51)
Ep 53:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 53 | T: 1.4700 | V: 1.5721 | IoU: 0.3666 | Wait (52)
Ep 54:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 54 | T: 1.4702 | V: 1.5725 | IoU: 0.3667 | Wait (53)
Ep 55:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 55 | T: 1.4706 | V: 1.5729 | IoU: 0.3663 | Wait (54)
Ep 56:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 56 | T: 1.4711 | V: 1.5727 | IoU: 0.3664 | Wait (55)
Ep 57:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 57 | T: 1.4706 | V: 1.5724 | IoU: 0.3666 | Wait (56)
Ep 58:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 58 | T: 1.4705 | V: 1.5724 | IoU: 0.3665 | Wait (57)
Ep 59:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 59 | T: 1.4699 | V: 1.5725 | IoU: 0.3666 | Wait (58)
Ep 60:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 60 | T: 1.4704 | V: 1.5727 | IoU: 0.3666 | Wait (59)
Ep 61:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 61 | T: 1.4706 | V: 1.5727 | IoU: 0.3665 | Wait (60)
Ep 62:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 62 | T: 1.4705 | V: 1.5724 | IoU: 0.3665 | Wait (61)
Ep 63:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 63 | T: 1.4701 | V: 1.5721 | IoU: 0.3667 | Wait (62)
Ep 64:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 64 | T: 1.4701 | V: 1.5722 | IoU: 0.3664 | Wait (63)
Ep 65:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 65 | T: 1.4694 | V: 1.5718 | IoU: 0.3664 | Wait (64)
Ep 66:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 66 | T: 1.4703 | V: 1.5714 | IoU: 0.3668 | Wait (65)
Ep 67:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 67 | T: 1.4703 | V: 1.5717 | IoU: 0.3667 | Wait (66)
Ep 68:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 68 | T: 1.4699 | V: 1.5720 | IoU: 0.3669 | Wait (67)
Ep 69:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 69 | T: 1.4697 | V: 1.5722 | IoU: 0.3671 | Wait (68)
Ep 70:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 70 | T: 1.4698 | V: 1.5720 | IoU: 0.3666 | Wait (69)
Ep 71:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 71 | T: 1.4702 | V: 1.5722 | IoU: 0.3666 | Wait (70)
Ep 72:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 72 | T: 1.4697 | V: 1.5723 | IoU: 0.3667 | Wait (71)
Ep 73:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 73 | T: 1.4707 | V: 1.5723 | IoU: 0.3667 | Wait (72)
Ep 74:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 74 | T: 1.4695 | V: 1.5716 | IoU: 0.3668 | Wait (73)
Ep 75:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 75 | T: 1.4704 | V: 1.5723 | IoU: 0.3666 | Wait (74)
Ep 76:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 76 | T: 1.4703 | V: 1.5724 | IoU: 0.3667 | Wait (75)
Ep 77:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 77 | T: 1.4696 | V: 1.5724 | IoU: 0.3667 | Wait (76)
Ep 78:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 78 | T: 1.4712 | V: 1.5728 | IoU: 0.3666 | Wait (77)
Ep 79:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 79 | T: 1.4707 | V: 1.5722 | IoU: 0.3666 | Wait (78)
Ep 80:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 80 | T: 1.4709 | V: 1.5729 | IoU: 0.3662 | Wait (79)
Ep 81:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 81 | T: 1.4699 | V: 1.5724 | IoU: 0.3666 | Wait (80)
Ep 82:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 82 | T: 1.4708 | V: 1.5727 | IoU: 0.3663 | Wait (81)
Ep 83:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 83 | T: 1.4700 | V: 1.5724 | IoU: 0.3665 | Wait (82)
Ep 84:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 84 | T: 1.4711 | V: 1.5723 | IoU: 0.3666 | Wait (83)
Ep 85:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 85 | T: 1.4698 | V: 1.5723 | IoU: 0.3669 | Wait (84)
Ep 86:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 86 | T: 1.4704 | V: 1.5715 | IoU: 0.3668 | Wait (85)
Ep 87:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 87 | T: 1.4701 | V: 1.5721 | IoU: 0.3665 | Wait (86)
Ep 88:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 88 | T: 1.4699 | V: 1.5721 | IoU: 0.3667 | Wait (87)
Ep 89:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 89 | T: 1.4700 | V: 1.5723 | IoU: 0.3667 | Wait (88)
Ep 90:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 90 | T: 1.4698 | V: 1.5723 | IoU: 0.3667 | Wait (89)
Ep 91:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 91 | T: 1.4711 | V: 1.5723 | IoU: 0.3669 | Wait (90)
Ep 92:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 92 | T: 1.4706 | V: 1.5721 | IoU: 0.3667 | Wait (91)
Ep 93:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 93 | T: 1.4706 | V: 1.5721 | IoU: 0.3666 | Wait (92)
Ep 94:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 94 | T: 1.4704 | V: 1.5721 | IoU: 0.3666 | Wait (93)
Ep 95:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 95 | T: 1.4701 | V: 1.5722 | IoU: 0.3664 | Wait (94)
Ep 96:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 96 | T: 1.4700 | V: 1.5719 | IoU: 0.3663 | Wait (95)
Ep 97:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 97 | T: 1.4705 | V: 1.5724 | IoU: 0.3664 | Wait (96)
Ep 98:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 98 | T: 1.4709 | V: 1.5722 | IoU: 0.3665 | Wait (97)
Ep 99:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 99 | T: 1.4707 | V: 1.5728 | IoU: 0.3666 | Wait (98)
Ep 100:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 100 | T: 1.4697 | V: 1.5724 | IoU: 0.3665 | Wait (99)
Ep 101:   0%|          | 0/9 [00:00<?, ?it/s]
 Triggering SWA...
Ep 101 | T: 1.4703 | V: 1.5719 | IoU: 0.3665 | Wait (100)
Ep 102:   0%|          | 0/9 [00:00<?, ?it/s]
/usr/local/lib/python3.12/dist-packages/torch/optim/lr_scheduler.py:192: UserWarning: Detected call of `lr_scheduler.step()` before `optimizer.step()`. In PyTorch 1.1.0 and later, you should call them in the opposite order: `optimizer.step()` before `lr_scheduler.step()`.  Failure to do this will result in PyTorch skipping the first value of the learning rate schedule. See more details at https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate
  warnings.warn(
Ep 102 | T: 1.4707 | V: 0.9056 | IoU: 0.6619 | SWA (1/50)
Ep 103:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 103 | T: 1.4703 | V: 0.9221 | IoU: 0.6619 | SWA (2/50)
Ep 104:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 104 | T: 1.4706 | V: 0.9854 | IoU: 0.6610 | SWA (3/50)
Ep 105:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 105 | T: 1.4702 | V: 1.1186 | IoU: 0.5708 | SWA (4/50)
Ep 106:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 106 | T: 1.4705 | V: 1.2967 | IoU: 0.4554 | SWA (5/50)
Ep 107:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 107 | T: 1.4703 | V: 1.4384 | IoU: 0.4018 | SWA (6/50)
Ep 108:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 108 | T: 1.4704 | V: 1.5144 | IoU: 0.3801 | SWA (7/50)
Ep 109:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 109 | T: 1.4701 | V: 1.5490 | IoU: 0.3717 | SWA (8/50)
Ep 110:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 110 | T: 1.4706 | V: 1.5634 | IoU: 0.3685 | SWA (9/50)
Ep 111:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 111 | T: 1.4710 | V: 1.5687 | IoU: 0.3677 | SWA (10/50)
Ep 112:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 112 | T: 1.4701 | V: 1.5711 | IoU: 0.3671 | SWA (11/50)
Ep 113:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 113 | T: 1.4702 | V: 1.5713 | IoU: 0.3667 | SWA (12/50)
Ep 114:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 114 | T: 1.4702 | V: 1.5717 | IoU: 0.3665 | SWA (13/50)
Ep 115:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 115 | T: 1.4705 | V: 1.5720 | IoU: 0.3666 | SWA (14/50)
Ep 116:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 116 | T: 1.4701 | V: 1.5727 | IoU: 0.3663 | SWA (15/50)
Ep 117:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 117 | T: 1.4702 | V: 1.5726 | IoU: 0.3665 | SWA (16/50)
Ep 118:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 118 | T: 1.4708 | V: 1.5725 | IoU: 0.3666 | SWA (17/50)
Ep 119:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 119 | T: 1.4708 | V: 1.5724 | IoU: 0.3665 | SWA (18/50)
Ep 120:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 120 | T: 1.4698 | V: 1.5725 | IoU: 0.3667 | SWA (19/50)
Ep 121:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 121 | T: 1.4703 | V: 1.5726 | IoU: 0.3668 | SWA (20/50)
Ep 122:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 122 | T: 1.4708 | V: 1.5723 | IoU: 0.3665 | SWA (21/50)
Ep 123:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 123 | T: 1.4707 | V: 1.5725 | IoU: 0.3664 | SWA (22/50)
Ep 124:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 124 | T: 1.4710 | V: 1.5726 | IoU: 0.3665 | SWA (23/50)
Ep 125:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 125 | T: 1.4703 | V: 1.5728 | IoU: 0.3663 | SWA (24/50)
Ep 126:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 126 | T: 1.4699 | V: 1.5722 | IoU: 0.3664 | SWA (25/50)
Ep 127:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 127 | T: 1.4708 | V: 1.5720 | IoU: 0.3663 | SWA (26/50)
Ep 128:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 128 | T: 1.4702 | V: 1.5723 | IoU: 0.3665 | SWA (27/50)
Ep 129:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 129 | T: 1.4701 | V: 1.5724 | IoU: 0.3666 | SWA (28/50)
Ep 130:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 130 | T: 1.4703 | V: 1.5725 | IoU: 0.3666 | SWA (29/50)
Ep 131:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 131 | T: 1.4709 | V: 1.5725 | IoU: 0.3667 | SWA (30/50)
Ep 132:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 132 | T: 1.4706 | V: 1.5726 | IoU: 0.3666 | SWA (31/50)
Ep 133:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 133 | T: 1.4706 | V: 1.5722 | IoU: 0.3668 | SWA (32/50)
Ep 134:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 134 | T: 1.4698 | V: 1.5720 | IoU: 0.3667 | SWA (33/50)
Ep 135:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 135 | T: 1.4710 | V: 1.5723 | IoU: 0.3667 | SWA (34/50)
Ep 136:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 136 | T: 1.4711 | V: 1.5722 | IoU: 0.3665 | SWA (35/50)
Ep 137:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 137 | T: 1.4702 | V: 1.5727 | IoU: 0.3664 | SWA (36/50)
Ep 138:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 138 | T: 1.4708 | V: 1.5725 | IoU: 0.3668 | SWA (37/50)
Ep 139:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 139 | T: 1.4707 | V: 1.5724 | IoU: 0.3667 | SWA (38/50)
Ep 140:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 140 | T: 1.4703 | V: 1.5724 | IoU: 0.3666 | SWA (39/50)
Ep 141:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 141 | T: 1.4715 | V: 1.5723 | IoU: 0.3664 | SWA (40/50)
Ep 142:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 142 | T: 1.4705 | V: 1.5723 | IoU: 0.3666 | SWA (41/50)
Ep 143:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 143 | T: 1.4708 | V: 1.5724 | IoU: 0.3667 | SWA (42/50)
Ep 144:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 144 | T: 1.4701 | V: 1.5724 | IoU: 0.3665 | SWA (43/50)
Ep 145:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 145 | T: 1.4709 | V: 1.5724 | IoU: 0.3666 | SWA (44/50)
Ep 146:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 146 | T: 1.4711 | V: 1.5726 | IoU: 0.3666 | SWA (45/50)
Ep 147:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 147 | T: 1.4697 | V: 1.5723 | IoU: 0.3665 | SWA (46/50)
Ep 148:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 148 | T: 1.4704 | V: 1.5723 | IoU: 0.3667 | SWA (47/50)
Ep 149:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 149 | T: 1.4697 | V: 1.5725 | IoU: 0.3666 | SWA (48/50)
Ep 150:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 150 | T: 1.4703 | V: 1.5725 | IoU: 0.3668 | SWA (49/50)
Ep 151:   0%|          | 0/9 [00:00<?, ?it/s]
Ep 151 | IoU: 0.3666 | ✅ SWA DONE
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
/usr/local/lib/python3.12/dist-packages/pandas/core/indexes/base.py in get_loc(self, key)
   3804         try:
-> 3805             return self._engine.get_loc(casted_key)
   3806         except KeyError as err:

index.pyx in pandas._libs.index.IndexEngine.get_loc()

index.pyx in pandas._libs.index.IndexEngine.get_loc()

pandas/_libs/hashtable_class_helper.pxi in pandas._libs.hashtable.PyObjectHashTable.get_item()

pandas/_libs/hashtable_class_helper.pxi in pandas._libs.hashtable.PyObjectHashTable.get_item()

KeyError: 'train_loss'

The above exception was the direct cause of the following exception:

KeyError                                  Traceback (most recent call last)
/tmp/ipython-input-1738852815.py in <cell line: 0>()
    223     if len(df) > 0:
    224         plt.figure(figsize=(10, 5))
--> 225         plt.plot(df['train_loss'], label='Train Loss')
    226         plt.plot(df['val_loss'], label='Val Loss')
    227         plt.legend()

/usr/local/lib/python3.12/dist-packages/pandas/core/frame.py in __getitem__(self, key)
   4100             if self.columns.nlevels > 1:
   4101                 return self._getitem_multilevel(key)
-> 4102             indexer = self.columns.get_loc(key)
   4103             if is_integer(indexer):
   4104                 indexer = [indexer]

/usr/local/lib/python3.12/dist-packages/pandas/core/indexes/base.py in get_loc(self, key)
   3810             ):
   3811                 raise InvalidIndexError(key)
-> 3812             raise KeyError(key) from err
   3813         except TypeError:
   3814             # If we have a listlike key, _check_indexing_error will raise

KeyError: 'train_loss'
<Figure size 1000x500 with 0 Axes>
In [12]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import os
from torch.utils.data import DataLoader

# --- CONFIG ---
SAVE_DIR = '/content/drive/MyDrive/SatMAE_Scratch_Results_12Frames/'
MODEL_PATH = os.path.join(SAVE_DIR, "final_swa_model.pth") # Using the SWA model

# --- 1. LOAD MODEL ---
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

try:
    # Re-initialize model architecture
    model = SatMAESegmentation(img_size=224, patch_size=16, in_chans=10, num_frames=12).to(device)

    # Load Weights
    print(f"Loading weights from: {os.path.basename(MODEL_PATH)}")
    state_dict = torch.load(MODEL_PATH, map_location=device)
    model.load_state_dict(state_dict)
    model.eval()
    print(" Model Loaded Successfully!")
except Exception as e:
    print(f" Error loading model: {e}")
    print("Make sure Cell 2 (Architecture) was run.")

# --- 2. LOAD ONE BATCH OF VAL DATA ---
# We use the existing 'val_ds' from previous cells
try:
    val_loader = DataLoader(val_ds, batch_size=1, shuffle=True)
    images, masks = next(iter(val_loader)) # Get 1 random sample
    images, masks = images.to(device), masks.to(device)
except NameError:
    print(" Error: 'val_ds' not found. Please re-run Cell 4 dataset setup part.")

# --- 3. RUN PREDICTION ---
with torch.no_grad():
    # Forward Pass
    logits = model(images)
    probs = torch.sigmoid(logits)
    pred_mask = (probs > 0.5).float() # Threshold at 50% confidence

# --- 4. VISUALIZATION ---
# Helper to normalize RGB for display
def get_rgb(img_tensor, frame_idx=6):
    # Shape: (B, T, C, H, W) -> (10, H, W) for one frame
    # Bands: ['B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B8', 'B8A', 'B11', 'B12']
    # Indices: B4(Red)=2, B3(Green)=1, B2(Blue)=0
    img = img_tensor[0, frame_idx].cpu().numpy()
    rgb = img[[2, 1, 0], :, :] # Pick RGB bands

    # Normalize for display (Sentinel-2 is roughly 0-0.3 reflectance typically, here scaled 0-1)
    rgb = np.clip(rgb * 3.0, 0, 1) # Brighten it up x3
    return np.transpose(rgb, (1, 2, 0)) # CHW -> HWC

# Prepare plots
fig, ax = plt.subplots(1, 3, figsize=(18, 6))

# A. Satellite Image (RGB from Frame 6)
ax[0].imshow(get_rgb(images))
ax[0].set_title("Sentinel-2 Input (RGB - Frame 6)")
ax[0].axis('off')

# B. Ground Truth Mask
ax[1].imshow(masks[0, 0].cpu().numpy(), cmap='gray')
ax[1].set_title("Ground Truth (Wheat Mask)")
ax[1].axis('off')

# C. Model Prediction
ax[2].imshow(pred_mask[0, 0].cpu().numpy(), cmap='jet', alpha=1.0)
ax[2].set_title("SatMAE Prediction")
ax[2].axis('off')

plt.tight_layout()
plt.show()

print(f"Sample IoU: {compute_metrics(probs, masks)[0]:.4f}")
Device: cuda
Loading weights from: final_swa_model.pth
✅ Model Loaded Successfully!
No description has been provided for this image
Sample IoU: 0.3799
In [13]:
import pandas as pd
import matplotlib.pyplot as plt
import torch
import numpy as np
import os
from torch.utils.data import DataLoader
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score, confusion_matrix

# --- CONFIG ---
SAVE_DIR = '/content/drive/MyDrive/SatMAE_Scratch_Results_12Frames/'
LOG_PATH = os.path.join(SAVE_DIR, "training_log.csv")
MODEL_PATH = os.path.join(SAVE_DIR, "final_swa_model.pth")
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ==========================================
# PART 1: PLOT TRAINING HISTORY
# ==========================================
if os.path.exists(LOG_PATH):
    try:
        df = pd.read_csv(LOG_PATH)

        # Setup Plots
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

        # Plot Loss
        ax1.plot(df['epoch'], df['train_loss'], label='Train Loss', color='orange', linewidth=2)
        ax1.plot(df['epoch'], df['val_loss'], label='Val Loss', color='blue', linewidth=2)
        ax1.set_title("Training vs Validation Loss")
        ax1.set_xlabel("Epochs")
        ax1.set_ylabel("Loss")
        ax1.legend()
        ax1.grid(True, alpha=0.3)

        # Plot IoU
        ax2.plot(df['epoch'], df['val_iou'], label='Val IoU', color='green', linewidth=2)
        ax2.set_title("Validation IoU (Quality) Over Time")
        ax2.set_xlabel("Epochs")
        ax2.set_ylabel("IoU Score")
        ax2.legend()
        ax2.grid(True, alpha=0.3)

        plt.tight_layout()
        plt.show()
        print(" Historical graphs generated from log file.")

    except Exception as e:
        print(f" Could not plot graphs: {e}")
else:
    print(" No log file found. Cannot plot history.")

# ==========================================
# PART 2: CALCULATE DEEP METRICS (F1, PRECISION, RECALL)
# ==========================================
print("\n" + "="*40)
print(" STARTING DEEP METRIC EVALUATION...")
print("="*40)

# 1. Load Model
try:
    model = SatMAESegmentation(img_size=224, patch_size=16, in_chans=10, num_frames=12).to(device)
    state_dict = torch.load(MODEL_PATH, map_location=device)
    model.load_state_dict(state_dict)
    model.eval()
    print(" Final Model Loaded.")
except Exception as e:
    print(f" Error loading model: {e}")
    # Stop execution if model fails
    raise RuntimeError("Model needed for metrics.")

# 2. Setup Data
try:
    # Use existing validation dataset
    eval_loader = DataLoader(val_ds, batch_size=1, shuffle=False)
except NameError:
    print(" Error: 'val_ds' not defined. Please run the Dataset Setup cell first.")

# 3. Accumulate Pixels (Global Calculation)
# We flatten all predictions and targets into one giant list for accurate global stats
all_preds = []
all_targets = []

print(f" Evaluating over {len(eval_loader)} validation samples...")

with torch.no_grad():
    for images, masks in eval_loader:
        images = images.to(device)

        # Predict
        logits = model(images)
        probs = torch.sigmoid(logits)
        preds = (probs > 0.5).cpu().numpy().flatten().astype(np.uint8)

        # Ground Truth
        targets = masks.cpu().numpy().flatten().astype(np.uint8)

        all_preds.append(preds)
        all_targets.append(targets)

# Concatenate all pixels
all_preds = np.concatenate(all_preds)
all_targets = np.concatenate(all_targets)

# 4. Compute Metrics
# Note: Zero_division=0 handles cases where model predicts nothing (safe division)
precision = precision_score(all_targets, all_preds, zero_division=0)
recall = recall_score(all_targets, all_preds, zero_division=0) # Sensitivity
f1 = f1_score(all_targets, all_preds, zero_division=0)
accuracy = accuracy_score(all_targets, all_preds)

# IoU Calculation (Jaccard)
intersection = np.logical_and(all_targets, all_preds).sum()
union = np.logical_or(all_targets, all_preds).sum()
iou = intersection / union if union > 0 else 0.0

# Confusion Matrix Elements
tn, fp, fn, tp = confusion_matrix(all_targets, all_preds).ravel()

# ==========================================
# PART 3: FINAL REPORT
# ==========================================
print("\n" + "="*40)
print(" FINAL MODEL PERFORMANCE REPORT")
print("="*40)
print(f"IoU (Intersection over Union):  {iou:.4f}  (Main Metric)")
print(f"F1 Score (Dice Coefficient):    {f1:.4f}    (Harmonic Mean)")
print("-" * 40)
print(f"Precision (True Pos / Pred Pos):{precision:.4f}  (How trustworthy 'Wheat' is)")
print(f"Recall (True Pos / Actual Pos): {recall:.4f}    (How much Wheat we found)")
print(f"Accuracy (Pixel-wise):          {accuracy:.4f}")
print("-" * 40)
print("Confusion Matrix (Pixels):")
print(f"   True Wheat Found (TP):       {tp}")
print(f"   Background Found (TN):       {tn}")
print(f"   Wheat Missed (FN):           {fn}")
print(f"   False Alarm (FP):            {fp}")
print("="*40)
 Could not plot graphs: 'epoch'

========================================
 STARTING DEEP METRIC EVALUATION...
========================================
 Final Model Loaded.
⏳ Evaluating over 3 validation samples...

========================================
 FINAL MODEL PERFORMANCE REPORT
========================================
IoU (Intersection over Union):  0.3686  (Main Metric)
F1 Score (Dice Coefficient):    0.5386    (Harmonic Mean)
----------------------------------------
Precision (True Pos / Pred Pos):0.6618  (How trustworthy 'Wheat' is)
Recall (True Pos / Actual Pos): 0.4541    (How much Wheat we found)
Accuracy (Pixel-wise):          0.4850
----------------------------------------
Confusion Matrix (Pixels):
   True Wheat Found (TP):       45250
   Background Found (TN):       27763
   Wheat Missed (FN):           54387
   False Alarm (FP):            23128
========================================
No description has been provided for this image